package mpqareader;

import se.lth.cs.nlp.nlputils.core.*;
import se.lth.cs.nlp.nlputils.ml_long.*;
import se.lth.cs.nlp.nlputils.depgraph.*;
import se.lth.cs.nlp.nlputils.framenet.*;
import se.lth.cs.nlp.depsrl.format.*;
import srlpostprocess.*;

// build path cycle
///import tkreranker.Linearizer;

import java.io.*;
import java.util.*;

/**
 * Bag-of-words subjectivity classifier for sentences (MPQA-style corpora).
 *
 * <p>Reads sentences annotated with subjectivity/polarity attributes, turns each
 * sentence into a {@link SparseVector} of lexical/syntactic/semantic features
 * (which feature families are active is controlled by the {@code USE_*} flags),
 * and trains/evaluates a classifier via cross-validation or a fixed test set.
 *
 * <p>NOTE(review): this file was recovered from a corrupted copy in which every
 * {@code <...>} span had been stripped (generic type parameters and the XML-tag
 * string literals). Generics were restored from usage; the destroyed tag-parsing
 * code is reconstructed minimally and marked with {@code TODO(review)} — those
 * spots must be confirmed against the original corpus format.
 */
public class BOWSubjSentClassifier {

    // --- global post-processing of feature vectors ---
    static final boolean SET = false;        // binarize all feature values to 1.0
    static final boolean NORMALIZE = true;   // L2-normalize each vector

    private static final boolean REMOVE_STOP_WORDS = false;
    private static final boolean DUMP_SVMLIGHT = false;
    private static final String SVMLIGHT_FILE = "svmlight.dump";

    // --- feature-family switches (current configuration: unigrams + subj. lexicon) ---
    private static final boolean USE_CHAR_TRIGRAMS = false;
    private static final boolean USE_FRAMES = false;
    private static final boolean USE_WB_FEATURES = false;
    private static final boolean REPLACE_LEMMAS = false;
    private static final boolean USE_UNIGRAMS = true;
    private static final boolean USE_POS = false;
    private static final boolean USE_BIGRAMS = false;
    private static final boolean USE_POS_BIGRAMS = false;
    private static final boolean USE_PRED_SENSES = false;
    private static final boolean USE_VO_TUPLES = false;
    private static final boolean USE_GFS = false;
    private static final boolean USE_PA_TUPLES = false;
    private static final boolean USE_RAW_PA_TUPLES = false;

    // Alternative "everything on" configuration kept from the original source:
    // private static final boolean USE_UNIGRAMS = true;
    // private static final boolean USE_POS = true;
    // private static final boolean USE_BIGRAMS = true;
    // private static final boolean USE_POS_BIGRAMS = true;
    // private static final boolean USE_PRED_SENSES = true;
    // private static final boolean USE_VO_TUPLES = true;
    // private static final boolean USE_GFS = true;
    // private static final boolean USE_PA_TUPLES = true;
    // private static final boolean USE_RAW_PA_TUPLES = false;

    private static final boolean USE_PREV_CLASS = false;
    static final boolean USE_SUBJLEX = true;
    private static final boolean USE_JUSSI_FEATURES = false;

    //private static boolean FIXED_TESTSET = true;

    /**
     * Reads one blank-line-terminated sentence (one token per line).
     *
     * BUG FIX: the original called {@code br.readLine().trim()} and would NPE
     * if the stream ended without a terminating blank line; EOF now ends the
     * sentence gracefully.
     */
    private static ArrayList<String> readSentence(BufferedReader br) throws IOException {
        ArrayList<String> out = new ArrayList<String>();
        String line = br.readLine();
        while(line != null && !line.trim().isEmpty()) {
            out.add(line.trim());
            line = br.readLine();
        }
        return out;
    }

    private static final String[] STOP_WORDS = new String[] {
        "a", "an", "and", "are", "as", "at", "be", "by", "for", "from",
        "has", "he", "in", "is", "it", "its", "of", "on", "that", "the",
        "to", "was", "were", "will", "with",
    };

    private static HashSet<String> STOP_WORD_SET = new HashSet<String>();
    static {
        STOP_WORD_SET.addAll(Arrays.asList(STOP_WORDS));
    }

    /**
     * Builds the feature vector for one sentence.
     *
     * @param sen    the tokens; lowercased and digit-normalized in place
     * @param g      parse/SRL info as returned by CoNLL2008Format.readNextGraph;
     *               may be null when only lexical features are active.
     *               NOTE(review): the generic parameters were stripped from the
     *               corrupted source; {@code Triple<DepGraph, DepGraph,
     *               ArrayList<PAStructure>>} is inferred from the uses of
     *               {@code g.first.nodes} and {@code g.third} — confirm.
     * @param previousWasSubjective gold/predicted class of the previous sentence
     *               (only used when USE_PREV_CLASS is on)
     * @param enc    symbol encoder mapping feature strings to indices
     * @return a sorted, trimmed sparse feature vector
     */
    static SparseVector representSentence(ArrayList<String> sen,
            Triple<DepGraph, DepGraph, ArrayList<PAStructure>> g,
            LexicalDB propBank, LexicalDB nomBank, FrameNet framenet,
            boolean previousWasSubjective, SubjectivityLexicon subjLex,
            SymbolEncoder enc) throws IOException {

        // Normalize tokens: lowercase, and map digits 1-9 to 0 (0 kept as-is,
        // matching the original behavior).
        for(int i = 0; i < sen.size(); i++) {
            String s = sen.get(i);
            s = s.toLowerCase();
            s = s.replaceAll("[1-9]", "0");
            sen.set(i, s);
        }

        SparseVector sv = new SparseVector();

        if(USE_PREV_CLASS) {
            sv.put(enc.encode("prev_" + previousWasSubjective), 1.0);
        }

        if(USE_CHAR_TRIGRAMS) {
            StringBuilder sb = new StringBuilder();
            for(int i = 0; i < sen.size(); i++)
                sb.append(sen.get(i));
            String s = "##" + sb.toString().trim() + "##";
            int l = s.length();
            for(int i = 0; i < l - 2; i++) {
                // BUG FIX: substring(i, i + 2) extracted a character *bigram*
                // although the feature is named and padded as a trigram.
                String t = "t_" + s.substring(i, i + 3);
                sv.put(enc.encode(t), 1.0);
            }
        }

        if(USE_FRAMES) {
            // One feature per FrameNet frame evokable by any token.
            for(int i = 0; i < sen.size(); i++) {
                String word = sen.get(i);
                String lemma = g.first.nodes[i+1].lemma;
                String pos = g.first.nodes[i+1].pos;
                if(lemma == null || lemma.equals("_"))
                    lemma = word;
                if(pos.startsWith("N")) pos = "n";
                else if(pos.startsWith("V")) pos = "v";
                else if(pos.startsWith("JJ")) pos = "a";
                else if(pos.startsWith("RB")) pos = "adv";
                else pos = null;
                if(pos != null) {
                    Set<LexUnit> lus = framenet.getLUsByHeadPos(lemma, pos);
                    if(lus != null)
                        for(LexUnit lu: lus)
                            sv.put(enc.encode("FN_" + lu.frame), 1.0);
                }
            }
        }

        if(USE_WB_FEATURES) {
            // Wiebe-style POS cues: modals (except "will"), adverbs (except
            // "not"), pronouns, adjectives, cardinals.
            for(int i = 0; i < sen.size(); i++) {
                String word = sen.get(i);
                String pos = g.first.nodes[i+1].pos;
                if(pos.equals("MD") && !word.equals("will"))
                    sv.put(enc.encode("WB_MD"), 1.0);
                else if(pos.startsWith("RB") && !word.equals("not"))
                    sv.put(enc.encode("WB_RB"), 1.0);
                else if(pos.startsWith("PRP"))
                    sv.put(enc.encode("WB_PRP"), 1.0);
                else if(pos.startsWith("JJ"))
                    sv.put(enc.encode("WB_JJ"), 1.0);
                else if(pos.startsWith("CD"))
                    sv.put(enc.encode("WB_CD"), 1.0);
            }
        }

        // unigrams
        if(USE_UNIGRAMS)
            for(int i = 0; i < sen.size(); i++) {
                String s = sen.get(i);
                if(REPLACE_LEMMAS) {
                    String l = g.first.nodes[i+1].lemma;
                    if(l != null && !l.equals("_"))
                        s = l;
                }
                if(!REMOVE_STOP_WORDS || !STOP_WORD_SET.contains(s))
                    // BUG FIX: the original encoded sen.get(i) here, so the
                    // REPLACE_LEMMAS substitution above never reached the
                    // emitted feature (it only affected the stop-word test).
                    sv.put(enc.encode(s), 1.0);
            }

        if(USE_POS)
            for(int i = 0; i < sen.size(); i++)
                sv.put(enc.encode(g.first.nodes[i+1].pos), 1.0);

        // bigrams
        if(USE_BIGRAMS)
            for(int i = 0; i < sen.size() - 1; i++) {
                String s = sen.get(i) + "_" + sen.get(i + 1);
                sv.put(enc.encode(s), 1.0);
            }

        // mixed word/POS bigrams: word_i + POS_{i+1} and POS_i + word_{i+1}
        if(USE_POS_BIGRAMS)
            for(int i = 1; i < g.first.nodes.length - 1; i++) {
                String w1 = sen.get(i-1);
                String w2 = sen.get(i);
                String p1 = g.first.nodes[i].pos;
                String p2 = g.first.nodes[i+1].pos;
                sv.put(enc.encode(w1 + "_" + p2), 1.0);
                sv.put(enc.encode(p1 + "_" + w2), 1.0);
            }

        if(USE_SUBJLEX) {
            // Subjectivity-lexicon clues (strong/weak subjective markers).
            for(int i = 1; i < g.first.nodes.length; i++) {
                String w = sen.get(i-1);
                String p = g.first.nodes[i].pos;
                String l = g.first.nodes[i].lemma;
                String slClue = subjLex.lookup(w, p, l);
                if(slClue != null) {
                    //System.err.println(l + "/" + p + ": slClue = " + slClue);
                    sv.put(enc.encode(slClue), 1.0);
                }
            }
        }

        // predicate lemmas
        if(USE_PRED_SENSES)
            for(PAStructure pa: g.third)
                sv.put(enc.encode(pa.lemma), 1.0);

        // verb--object tuples
        if(USE_VO_TUPLES) {
            DepGraph dg = g.first;
            for(int i = 1; i < dg.nodes.length; i++) {
                DepNode n = dg.nodes[i];
                DepNode p = n.parents[0];
                if(n.relations[0].equals("OBJ")) {
                    String s = p.word + "_OBJ_" + n.word;
                    s = s.toLowerCase();
                    sv.put(enc.encode(s), 1.0);
                }
                /*else if(n.relations[0].equals("SBJ")) {
                    String s = p.word + "_SBJ_" + n.word;
                    s = s.toLowerCase();
                    sv.put(enc.encode(s), 1.0);
                }*/
            }
        }

        // syntactic functions
        if(USE_GFS) {
            DepGraph dg = g.first;
            for(int i = 1; i < dg.nodes.length; i++)
                sv.put(enc.encode(dg.nodes[i].relations[0]), 1.0);
        }

        // raw predicate--argument tuples (surface argument words)
        if(USE_RAW_PA_TUPLES)
            for(PAStructure pa: g.third) {
                for(int i = 0; i < pa.args.size(); i++) {
                    String l = pa.argLabels.get(i);
                    DepNode a = pa.args.get(i);
                    if(l.matches("A0|A1|A2")) {
                        String s = pa.lemma + "_" + l + "_" + a.word;
                        s = s.toLowerCase();
                        sv.put(enc.encode(s), 1.0);
                    }
                }
            }

        // predicate--argument tuples over post-processed semantic nodes
        if(USE_PA_TUPLES) {
            ArrayList<SemNode> ns =
                SRLPostProcess.processPAs(g.third, g.first, propBank, nomBank);
            for(SemNode n: ns) {
                if(!(n instanceof EventSemNode))
                    continue;
                EventSemNode en = (EventSemNode) n;
                for(ArgLink l: en.args) {
                    if(l.rid.matches("A0|A1|A2|A3")) {
                        String as;
                        if(l.arg instanceof EventSemNode)
                            as = ((EventSemNode) l.arg).lemma;
                        else
                            as = ((TokenSemNode) l.arg).word;
                        String ps = en.lemma;
                        String rel = l.rid;
                        /*if(l.vn != null) rel = l.vn; else rel = l.rid;*/
                        String tuple = ps + "_" + rel + "_" + as;
                        String tuple1 = en.lemma + "_" + rel;
                        //String tuple2 = rel + "_" + as;
                        sv.put(enc.encode(tuple), 1.0);
                        sv.put(enc.encode(tuple1), 1.0);
                        //sv.put(enc.encode(tuple2), 1.0);
                    }
                }
            }
        }

        if(USE_JUSSI_FEATURES) {
            // Tense shift between an object clause and its governor
            // (VBD under VBP/VBZ or vice versa).
            for(int i = 1; i < g.first.nodes.length; i++) {
                DepNode n = g.first.nodes[i];
                DepNode p = n.parents[0];
                if(p.position == 0)
                    continue;
                if(!n.relations[0].equals("OBJ"))
                    continue;
                if(n.pos.equals("VBD") && p.pos.matches("VBP|VBZ")
                   || n.pos.matches("VBP|VBZ") && p.pos.equals("VBD")) {
                    sv.put(enc.encode("#tenseshift"), 1.0);
                }
            }
        }

        sv.sortIndices();
        sv.trim();
        return sv;
    }

    /**
     * In-place TF-IDF reweighting: multiplies each stored value by
     * log(N / df(feature)).
     *
     * BUG FIX: features unseen in the df table (df == 0, which happens when
     * weighting test vectors with training-set counts) produced an infinite
     * weight; they are now treated as df == 1.
     */
    private static void weightIDF(SparseVector sv, int N, IntHistogram dfs) {
        for(int i = 0; i < sv.index; i++) {
            int df = dfs.getFrequency((int) sv.keys[i]);
            double idf = Math.log(N) - Math.log(Math.max(df, 1));
            sv.values[i] *= idf;
        }
    }

    /** L2-normalizes the vector in place; a zero vector is left untouched. */
    static void normalize2(SparseVector sv) {
        double sum = 0;
        for(int i = 0; i < sv.index; i++)
            sum += sv.values[i] * sv.values[i];
        if(sum == 0)
            return; // BUG FIX: avoid 1/0 -> infinite values for empty vectors
        double isum = 1.0 / Math.sqrt(sum); // obs tidigare inte sqrt!!!
        for(int i = 0; i < sv.index; i++)
            sv.values[i] *= isum;
    }

    private static final boolean USE_POLARITY = false;

    // class labels
    static final int SUBJ = 1, NO_SUBJ = 2;
    static final int POS_SUBJ = 11, NEG_SUBJ = 12, OTHER_SUBJ = 13;

    /**
     * Trains and evaluates with n-fold cross-validation (or, when
     * argv[2] == "true", a single train/test split defined by the test-set file
     * list), then trains a final model on all data and serializes it.
     *
     * argv: 1=corpus 2=testing 3=testsetFilelist 4=synsemFile 5=pbDir 6=nbDir
     *       7=fnFile 8=fnRelFile 9=subjLexFile 10=nfolds 11=idf 12=outFile
     *       13=modelName 14=algName 15=algArgs
     */
    public static void train_cv(String[] argv) {
        String fileName = argv[1];
        boolean testing = Boolean.parseBoolean(argv[2]);
        String testsetFilelist = argv[3];
        String synsemFileName = argv[4];
        String pbDir = argv[5];
        String nbDir = argv[6];
        String fnFile = argv[7];
        String fnRelFile = argv[8];
        String subjLexFile = argv[9];
        int nfolds = Integer.parseInt(argv[10]);
        boolean idf = Boolean.parseBoolean(argv[11]);
        String outFile = argv[12];
        String modelName = argv[13];
        String algName = argv[14];
        String algArgs = argv[15];
        try {
            AlgorithmFactory.setVerbosity(0);

            // Files belonging to the held-out test set, one name per line.
            HashSet<String> testsetFiles = new HashSet<String>();
            BufferedReader tsbr = new BufferedReader(new FileReader(testsetFilelist));
            String line = tsbr.readLine();
            while(line != null) {
                line = line.trim();
                testsetFiles.add(line);
                line = tsbr.readLine();
            }
            tsbr.close(); // BUG FIX: reader was never closed

            ArrayList<Pair<SparseVector, Integer>> ts = new ArrayList<Pair<SparseVector, Integer>>();
            ArrayList<Pair<SparseVector, Integer>> testset = new ArrayList<Pair<SparseVector, Integer>>();
            ArrayList<Pair<SparseVector, Integer>> tsP = new ArrayList<Pair<SparseVector, Integer>>();

            BufferedReader br = new BufferedReader(new FileReader(fileName));
            PrintWriter out = null;
            if(outFile != null)
                out = new PrintWriter(new FileWriter(outFile));
            Scanner sc = new Scanner(Util.openFileStream(synsemFileName));
            LexicalDB propBank = new LexicalDB(pbDir);
            LexicalDB nomBank = new LexicalDB(nbDir);
            FrameNet fn = null;
            if(USE_FRAMES)
                fn = new FrameNet(fnFile, fnRelFile);
            SubjectivityLexicon subjLex = new SubjectivityLexicon(subjLexFile);
            SymbolEncoder enc = new SymbolEncoder();

            boolean inTestSet = false;
            boolean prevSubjective = false;
            while(true) {
                line = br.readLine();
                if(line == null)
                    break;
                // TODO(review): the original tag tests were destroyed when the
                // source was corrupted (all "<...>" spans stripped). The code
                // below is a minimal reconstruction: a document header updates
                // inTestSet, a sentence header carries the subjectivity and
                // polarity attributes. Confirm tag/attribute names against the
                // corpus format.
                if(line.startsWith("<file")) {
                    inTestSet = false;
                    for(String tf: testsetFiles)
                        if(!tf.isEmpty() && line.contains(tf)) {
                            inTestSet = true;
                            break;
                        }
                    continue;
                }
                if(!line.startsWith("<sen"))
                    continue;
                boolean isSubjective = line.contains("subjective=\"true\"");
                String pol = "O";
                if(line.contains("pol=\"P\"")) pol = "P";
                else if(line.contains("pol=\"N\"")) pol = "N";

                ArrayList<String> sen = readSentence(br);
                //System.out.println(isSubjective + "\t" + sen);
                Triple<DepGraph, DepGraph, ArrayList<PAStructure>> g =
                    CoNLL2008Format.readNextGraph(sc);
                SparseVector sv = representSentence(sen, g, propBank, nomBank, fn,
                                                    prevSubjective, subjLex, enc);
                prevSubjective = isSubjective;
                if(SET)
                    for(int j = 0; j < sv.index; j++)
                        sv.values[j] = 1.0;
                if(NORMALIZE) {
                    normalize2(sv);
                    /*double sqrlen = SparseVector.sortedSqrLength(sv);
                    if(Math.abs(sqrlen - 1.0) > 1e-12)
                        throw new RuntimeException("!!!");*/
                }
                //System.out.println(sv);
                if(USE_POLARITY) {
                    int cls;
                    if(!isSubjective) cls = NO_SUBJ;
                    else if(pol.equals("O")) cls = OTHER_SUBJ;
                    else if(pol.equals("P")) cls = POS_SUBJ;
                    else if(pol.equals("N")) cls = NEG_SUBJ;
                    else throw new RuntimeException("!!!");
                    if(!testing || !inTestSet) {
                        ts.add(new Pair<SparseVector, Integer>(sv, isSubjective? SUBJ: NO_SUBJ));
                        tsP.add(new Pair<SparseVector, Integer>(sv, cls));
                    } else
                        testset.add(new Pair<SparseVector, Integer>(sv, isSubjective? SUBJ: NO_SUBJ));
                } else {
                    if(!testing || !inTestSet)
                        ts.add(new Pair<SparseVector, Integer>(sv, isSubjective? SUBJ: NO_SUBJ));
                    else
                        testset.add(new Pair<SparseVector, Integer>(sv, isSubjective? SUBJ: NO_SUBJ));
                }
            }
            br.close(); // BUG FIX: reader was never closed

            if(DUMP_SVMLIGHT) {
                // Deliberately disabled in the original source.
                if(true) throw new RuntimeException("unimplemented");
                PrintWriter pw = new PrintWriter(new FileWriter(SVMLIGHT_FILE));
                for(Pair<SparseVector, Integer> p: ts) {
                    pw.print(p.right.equals(SUBJ)? "+1": "-1");
                    for(int i = 0; i < p.left.index; i++)
                        pw.printf(" %d:%f", p.left.keys[i], p.left.values[i]);
                    pw.println();
                }
                pw.close();
                System.out.println("Wrote SVMlight training data.");
                System.exit(0);
            }

            ConfusionMatrix cm = new ConfusionMatrix();
            int ntp = 0, nfp = 0, ntn = 0, nfn = 0;
            int ncorr = 0;
            if(testing)
                nfolds = 1;
            if(false) { // inte hela korpusen utan bara det lilla
                ts = testset;
            }
            int foldSize = ts.size() / nfolds;
            //System.out.println("foldSize = " + foldSize);
            int count = 0;
            HashSet<Integer> fpSenIndices = new HashSet<Integer>();
            HashSet<Integer> fnSenIndices = new HashSet<Integer>();
            ArrayList<Integer> range = new ArrayList<Integer>();
            range.add(SUBJ);
            range.add(NO_SUBJ);
            double[] scores = new double[2];
            ArrayList<DoubleIntPair> scoreList = new ArrayList<DoubleIntPair>();
            Random rand = new Random(0); // kept: referenced by commented-out baseline below

            for(int i = 0; i < nfolds; i++) {
                System.out.println("Fold " + (i + 1) + ".");
                ArrayList<Pair<SparseVector, Integer>> trs = new ArrayList<Pair<SparseVector, Integer>>();
                ArrayList<Pair<SparseVector, Integer>> tes = new ArrayList<Pair<SparseVector, Integer>>();
                ArrayList<Pair<SparseVector, Integer>> trsP = new ArrayList<Pair<SparseVector, Integer>>();
                ArrayList<Pair<SparseVector, Integer>> tesP = new ArrayList<Pair<SparseVector, Integer>>();
                if(!testing) {
                    // training part before the fold
                    for(int j = 0; j < i*foldSize; j++) {
                        Pair<SparseVector, Integer> p = ts.get(j);
                        trs.add(new Pair<SparseVector, Integer>(p.left.copy(), p.right));
                        if(USE_POLARITY) {
                            Pair<SparseVector, Integer> pP = tsP.get(j);
                            // polarity classifier is trained on subjective sentences only
                            if(p.right.equals(SUBJ))
                                trsP.add(new Pair<SparseVector, Integer>(pP.left.copy(), pP.right));
                        }
                    }
                    // the fold itself is the test part; last fold absorbs the remainder
                    int fend = (i == nfolds - 1)? ts.size(): (i + 1)*foldSize;
                    for(int j = i*foldSize; j < fend; j++) {
                        Pair<SparseVector, Integer> p = ts.get(j);
                        tes.add(new Pair<SparseVector, Integer>(p.left.copy(), p.right));
                        if(USE_POLARITY) {
                            Pair<SparseVector, Integer> pP = tsP.get(j);
                            tesP.add(new Pair<SparseVector, Integer>(pP.left.copy(), pP.right));
                        }
                    }
                    // training part after the fold
                    for(int j = fend; j < ts.size(); j++) {
                        Pair<SparseVector, Integer> p = ts.get(j);
                        trs.add(new Pair<SparseVector, Integer>(p.left.copy(), p.right));
                        if(USE_POLARITY) {
                            Pair<SparseVector, Integer> pP = tsP.get(j);
                            if(p.right.equals(SUBJ))
                                trsP.add(new Pair<SparseVector, Integer>(pP.left.copy(), pP.right));
                        }
                    }
                } else {
                    trs = ts;
                    tes = testset;
                }

                if(idf) {
                    // document frequencies are computed on the training part only
                    IntHistogram dfs = new IntHistogram();
                    for(Pair<SparseVector, Integer> p: trs)
                        for(int j = 0; j < p.left.index; j++)
                            dfs.add((int) p.left.keys[j]);
                    for(Pair<SparseVector, Integer> p: trs)
                        weightIDF(p.left, trs.size(), dfs);
                    for(Pair<SparseVector, Integer> p: tes)
                        weightIDF(p.left, trs.size(), dfs);
                }

                AlgorithmFactory af = new AlgorithmFactory();
                LearningAlgorithm alg = af.create(algName, algArgs);
                EncodedMulticlassDefinition def =
                    new EncodedMulticlassDefinition(new int[] { SUBJ, NO_SUBJ });
                Classifier cl = alg.train(def, trs);
                Classifier clP = null;
                if(USE_POLARITY)
                    // NOTE(review): trained with the SUBJ/NO_SUBJ definition
                    // although trsP is labeled POS/NEG/OTHER_SUBJ — looks
                    // suspicious, preserved from the original; verify.
                    clP = alg.train(def, trsP);

                boolean prevSubjectiveGuess = false;
                int PREV_TRUE = enc.encode("prev_true");
                int PREV_FALSE = enc.encode("prev_false");
                for(int j = 0; j < tes.size(); j++) {
                    Pair<SparseVector, Integer> p = tes.get(j);
                    Pair<SparseVector, Integer> pP = null;
                    if(USE_POLARITY)
                        pP = tesP.get(j);
                    count++;
                    // Replace the gold previous-class feature with the guessed one.
                    if(true) {
                        for(int k = 0; k < p.left.index; k++) {
                            if(p.left.keys[k] == PREV_TRUE || p.left.keys[k] == PREV_FALSE) {
                                if(prevSubjectiveGuess)
                                    p.left.keys[k] = PREV_TRUE;
                                else
                                    p.left.keys[k] = PREV_FALSE;
                            }
                        }
                        p.left.sortIndices();
                    }
                    //int guess = cl.classify(p.left);
                    cl.computeScoresRestricted(p.left, range, scores);
                    int pos = CollectionUtils.maxIndex(scores, scores.length);
                    double score = scores[pos];
                    int guess = range.get(pos);
                    // negative score marks an error (used for error-ranked dumps)
                    if(!p.right.equals(guess))
                        score *= -1;
                    scoreList.add(new DoubleIntPair(score, count));
                    if(out != null)
                        out.println(p.right + "\t" + guess);
                    if(p.right.equals(guess))
                        ncorr++;
                    if(!USE_POLARITY) {
                        if(guess == SUBJ && p.right.equals(SUBJ)) ntp++;
                        else if(guess == NO_SUBJ && p.right.equals(SUBJ)) nfn++;
                        else if(guess == SUBJ && p.right.equals(NO_SUBJ)) nfp++;
                        else ntn++;
                        cm.add(p.right, guess);
                    } else {
                        // two-stage: polarity classifier runs on predicted-subjective only
                        if(guess == SUBJ)
                            guess = clP.classify(pP.left);
                        //if(guess == SUBJ)
                        //    guess = rand.nextInt(3) + 11;
                        cm.add(pP.right, guess);
                    }
                    prevSubjectiveGuess = guess == SUBJ;
                    if(guess == NO_SUBJ && p.right.equals(SUBJ)) {
                        fnSenIndices.add(count);
                        //System.out.println("FN: p = " + p);
                    } else if(guess == SUBJ && p.right.equals(NO_SUBJ)) {
                        fpSenIndices.add(count);
                        //System.out.println("FP: p = " + p);
                    }
                }
            }
            if(out != null)
                out.close(); // BUG FIX: unconditional close NPE'd when outFile was null

            /* if(SET) {
                for(Pair p: ts)
                    for(int j = 0; j < p.left.index; j++)
                        p.left.values[j] = 1.0;
            }*/

            // Train the final model on all the data and serialize it.
            if(true) {
                AlgorithmFactory af = new AlgorithmFactory();
                AlgorithmFactory.setVerbosity(0);
                LearningAlgorithm alg = af.create(algName, algArgs);
                EncodedMulticlassDefinition def =
                    new EncodedMulticlassDefinition(new int[] { SUBJ, NO_SUBJ });
                Classifier cl = alg.train(def, ts); // TODO idf?
                if(false)
                    printWeights(enc, cl);
                if(true) {
                    enc.freeze();
                    ObjectOutputStream oos =
                        new ObjectOutputStream(new FileOutputStream(modelName));
                    oos.writeObject(enc);
                    oos.writeObject(cl);
                    if(USE_SUBJLEX)
                        oos.writeObject(subjLex);
                    oos.close();
                }
            }

            double s_pr = (double) ntp / (ntp + nfp);
            double s_re = (double) ntp / (ntp + nfn);
            double s_f1 = 2*s_pr*s_re / (s_pr + s_re);
            double s_acc = (double) ncorr / count;
            if(!USE_POLARITY) {
                System.out.printf("S pr = %d / %d = %f\n", ntp, ntp + nfp, s_pr);
                System.out.printf("S re = %d / %d = %f\n", ntp, ntp + nfn, s_re);
                System.out.printf("S f1 = %f\n", s_f1);
            }
            System.out.printf("S acc = %d / %d = %f\n", ncorr, count, s_acc);
            System.out.println(cm);
            if(false)
                printErrorSentences(fileName, fpSenIndices, fnSenIndices, ts, scoreList);
        } catch(Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Prints the normalized token form of every sentence in the corpus,
     * one sentence per output line.
     */
    public static void printSentences(String[] argv) {
        String fileName = argv[1];
        try {
            BufferedReader br = new BufferedReader(new FileReader(fileName));
            SymbolEncoder enc = new SymbolEncoder();
            String line;
            while(true) {
                line = br.readLine();
                if(line == null)
                    break;
                // TODO(review): original sentence-tag test lost in corruption;
                // reconstructed — confirm tag name.
                if(!line.startsWith("<sen"))
                    continue;
                ArrayList<String> sen = readSentence(br);
                // Only called for its token-normalization side effect; the graph
                // and lexicon arguments may be null because the corresponding
                // feature flags are off.
                representSentence(sen, null, null, null, null, false, null, enc);
                for(int i = 0; i < sen.size(); i++) {
                    if(i > 0)
                        System.out.print(" ");
                    System.out.print(sen.get(i));
                }
                System.out.println();
            }
            br.close();
        } catch(Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Prints every feature's weight (score relative to the empty-vector bias),
     * sorted descending. Probes the classifier with one-hot vectors.
     */
    private static void printWeights(SymbolEncoder enc, Classifier cl) {
        ArrayList<String> encInv = enc.inverse();
        int max = enc.size();
        SparseVector sv = new SparseVector();
        ArrayList<Integer> range = new ArrayList<Integer>();
        range.add(SUBJ);
        range.add(NO_SUBJ);
        double[] scores = new double[2];
        cl.computeScoresRestricted(new SparseVector(), range, scores);
        double bias = scores[0];
        System.out.println("bias = " + bias);
        ArrayList<DoubleObjPair<String>> l = new ArrayList<DoubleObjPair<String>>();
        // probe with a single-feature vector, reusing the same buffer
        sv.index = 1;
        sv.values[0] = 1.0;
        for(int i = 1; i < max; i++) {
            sv.keys[0] = i;
            cl.computeScoresRestricted(sv, range, scores);
            String k = encInv.get(i);
            l.add(new DoubleObjPair<String>(scores[0] - bias, k));
            //System.out.println(i + "\t" + scores[0] + "\t" + scores[1]);
        }
        Collections.sort(l, new NegatedComparator(DoubleObjPair.BY_LEFT));
        for(DoubleObjPair<String> p: l)
            System.out.println(p);
    }

    /**
     * Re-reads the corpus and prints every sentence prefixed by its
     * classification score, sorted (most confident errors first, since error
     * scores were negated in train_cv).
     */
    private static void printErrorSentences(String filename,
            HashSet<Integer> fpIndices, HashSet<Integer> fnIndices,
            ArrayList<Pair<SparseVector, Integer>> ts,
            ArrayList<DoubleIntPair> scoreList) throws IOException {
        BufferedReader br = new BufferedReader(new FileReader(filename));
        String line = null;
        int count = 0;
        ArrayList<String> sentences = new ArrayList<String>();
        while(true) {
            line = br.readLine();
            if(line == null)
                break;
            // TODO(review): sentence-header parsing reconstructed after
            // corruption; counting is per sentence so that indices line up
            // with the per-sentence counter in train_cv — confirm.
            if(!line.startsWith("<sen"))
                continue;
            count++;
            boolean isSubjective = line.contains("subjective=\"true\"");
            ArrayList<String> sen = readSentence(br);
            sentences.add((isSubjective? "Subjective": "Objective")
                          + "\t" + Strings.join(sen, " "));
            //if(fpIndices.contains(count))
            //    System.out.println("FP\t" + sen);
            //else if(fnIndices.contains(count))
            //    System.out.println("FN\t " + sen);
            //else
            //    System.out.println("OK\t " + sen + "\t" + ts.get(count-1));
        }
        br.close();
        Collections.sort(scoreList, DoubleIntPair.BY_LEFT);
        for(DoubleIntPair p: scoreList) {
            System.out.println(p.left + "\t" + sentences.get(p.right - 1));
        }
    }

    /** Subjectivity-lexicon clue for one node, or null if none. */
    private static String getSubjectivityClue(DepNode n, SubjectivityLexicon subjLex) {
        return subjLex.lookup(n.word, n.pos, n.lemma);
    }

    /** Clue per node position (index 0 unused, matching DepGraph convention). */
    private static String[] getSubjectivityClues(DepGraph dg, SubjectivityLexicon subjLex) {
        String[] out = new String[dg.nodes.length];
        for(int i = 1; i < dg.nodes.length; i++)
            out[i] = getSubjectivityClue(dg.nodes[i], subjLex);
        return out;
    }

    /** Destructively replaces each clue-bearing node's word by its clue label. */
    private static void modifyTreeSubjNodes(DepGraph dg, SubjectivityLexicon subjLex) {
        for(int i = 1; i < dg.nodes.length; i++) {
            String clue = getSubjectivityClue(dg.nodes[i], subjLex);
            if(clue != null)
                dg.nodes[i].word = clue;
            //out[i] = getSubjectivityClue(dg.nodes[i], subjLex);
        }
    }

    /**
     * Flat bag-of-words tree representation for the tree-kernel tools, with
     * Penn-Treebank bracket escaping and optional subjectivity-clue markers.
     */
    private static String bowRepr(DepGraph dg, SubjectivityLexicon subjLex) {
        StringBuilder out = new StringBuilder();
        out.append("(BOW ");
        for(int i = 1; i < dg.nodes.length; i++) {
            String w = dg.nodes[i].word;
            w = w.toLowerCase();
            w = w.replaceAll("[1-9]", "0");
            w = w.replaceAll("\\(", "-LRB-");
            w = w.replaceAll("\\)", "-RRB-");
            w = w.replaceAll("\\[", "-LSB-");
            w = w.replaceAll("\\]", "-RSB-");
            w = w.replaceAll("\\{", "-LCB-");
            w = w.replaceAll("\\}", "-RCB-");
            String slClue = getSubjectivityClue(dg.nodes[i], subjLex);
            if(USE_SUBJLEX)
                if(slClue != null)
                    out.append("(#b_" + slClue + ")");
            out.append("(" + w + ")");
            if(USE_SUBJLEX)
                if(slClue != null)
                    out.append("(#e_" + slClue + ")");
        }
        out.append(")");
        return out.toString();
    }

    /** Flat bag-of-POS-tags representation for the tree-kernel tools. */
    private static String posRepr(DepGraph dg) {
        StringBuilder out = new StringBuilder();
        out.append("(BOP ");
        for(int i = 1; i < dg.nodes.length; i++) {
            String w = dg.nodes[i].pos;
            w = w.toLowerCase();
            w = w.replaceAll("\\(", "-LRB-");
            w = w.replaceAll("\\)", "-RRB-");
            out.append("(" + w + ")");
        }
        out.append(")");
        return out.toString();
    }

    /**
     * Writes tree-kernel training/test files (one |BT|-separated representation
     * list per sentence). NOTE: currently always throws once it reaches the
     * dependency-tree representation, because the Linearizer import had to be
     * commented out (build path cycle) — preserved from the original.
     */
    public static void tkOutput(String[] argv) {
        boolean usePolarity = Boolean.parseBoolean(argv[1]);
        boolean multiclass = Boolean.parseBoolean(argv[2]);
        if(!usePolarity && multiclass)
            throw new RuntimeException("illegal setting");
        String fileName = argv[3];
        String testsetFilelist = argv[4];
        String synsemFileName = argv[5];
        String constFileName = argv[6];
        String tkTrainOutputFile = argv[7];
        String tkTestOutputFile = argv[8];
        String pbDir = argv[9];
        String nbDir = argv[10];
        String subjLexFile = argv[11];
        try {
            HashSet<String> testsetFiles = new HashSet<String>();
            BufferedReader tsbr = new BufferedReader(new FileReader(testsetFilelist));
            String line = tsbr.readLine();
            while(line != null) {
                line = line.trim();
                testsetFiles.add(line);
                line = tsbr.readLine();
            }
            tsbr.close();
            BufferedReader br = new BufferedReader(new FileReader(fileName));
            PrintWriter pwTrain = new PrintWriter(new FileWriter(tkTrainOutputFile));
            PrintWriter pwTest = new PrintWriter(new FileWriter(tkTestOutputFile));
            BufferedReader cbr = new BufferedReader(new FileReader(constFileName));
            Scanner sc = new Scanner(new File(synsemFileName));
            LexicalDB propBank = new LexicalDB(pbDir);
            LexicalDB nomBank = new LexicalDB(nbDir);
            SubjectivityLexicon subjLex = new SubjectivityLexicon(subjLexFile);
            PrintWriter pw = null;
            while(true) {
                line = br.readLine();
                if(line == null)
                    break;
                // TODO(review): header parsing reconstructed after corruption —
                // a document header selects the train/test writer, a sentence
                // header carries the class attributes. Confirm tag names.
                if(line.startsWith("<file")) {
                    boolean inTestSet = false;
                    for(String tf: testsetFiles)
                        if(!tf.isEmpty() && line.contains(tf)) {
                            inTestSet = true;
                            break;
                        }
                    pw = inTestSet? pwTest: pwTrain;
                    continue;
                }
                if(!line.startsWith("<sen"))
                    continue;
                boolean isSubjective = line.contains("subjective=\"true\"");
                int pol = 0;
                boolean hasPos = false, hasNeg = false, hasNeu = false;
                /* class counts preserved from the original source:
                   1566 1023 1144 13 43 31 29 */
                if(line.contains("pol=\"positive\"")
                   || line.contains("pol=\"uncertain-positive\""))
                    pol = +1;
                else if(line.contains("pol=\"negative\"")
                        || line.contains("pol=\"uncertain-negative\""))
                    pol = -1;
                if(multiclass) {
                    if(line.contains("hasPos=\"true\"")) hasPos = true;
                    if(line.contains("hasNeg=\"true\"")) hasNeg = true;
                    if(line.contains("hasNeu=\"true\"")) hasNeu = true;
                }
                //if(isSubjective)
                //    System.err.println("*");
                ArrayList<String> sen = readSentence(br);
                Triple<DepGraph, DepGraph, ArrayList<PAStructure>> g =
                    CoNLL2008Format.readNextGraph(sc);
                String constRepr = cbr.readLine();
                // skip pathological training graphs / failed constituent parses
                if(pw == pwTrain && isStrangeGraph(g.first)) {
                    continue;
                }
                if(pw == pwTrain && constRepr.equals("(S0 FAILED)")) {
                    continue;
                }
                if(!usePolarity) {
                    if(isSubjective) pw.print("+1");
                    else pw.print("-1");
                } else if(!multiclass) {
                    if(!isSubjective) pw.print("0");
                    else pw.print(pol + 2);
                } else {
                    if(hasPos) pw.print("+");
                    if(hasNeu) pw.print("0");
                    if(hasNeg) pw.print("-");
                }
                String bowRepr = bowRepr(g.first, subjLex);
                pw.print(" |BT| " + bowRepr);
                String posRepr = posRepr(g.first);
                //modifyTreeSubjNodes(g.first, subjLex);
                //String depRepr = Linearizer.linearize1(g.first);
                String depRepr = null;
                if(true)
                    throw new RuntimeException("had to comment out due to build path cycle");
                pw.print(" |BT| " + posRepr);
                pw.print(" |BT| " + depRepr);
                pw.print(" |BT| " + constRepr.toLowerCase());
                ArrayList<SemNode> sns =
                    SRLPostProcess.processPAs(g.third, g.first, propBank, nomBank);
                String semRepr1 = TKLinearizer.flatTreeRepresentation(sns);
                String semRepr2 = TKLinearizer.flatTreeRepresentation2(sns);
                pw.print(" |BT| " + semRepr1);
                pw.print(" |BT| " + semRepr2);
                pw.println(" |ET|");
            }
            pwTrain.close();
            pwTest.close();
        } catch(Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Writes each sentence into per-class directories (pos/neu/neg) under a
     * train or test root, for tools that expect one file per example.
     */
    public static void dirs(String[] argv) {
        boolean usePolarity = Boolean.parseBoolean(argv[1]);
        boolean multiclass = Boolean.parseBoolean(argv[2]);
        if(!usePolarity || !multiclass)
            throw new RuntimeException("illegal setting");
        String fileName = argv[3];
        String testsetFilelist = argv[4];
        String trainOutDir = argv[5];
        String testOutDir = argv[6];
        try {
            new File(trainOutDir).mkdir();
            new File(trainOutDir + "/pos").mkdir();
            new File(trainOutDir + "/neu").mkdir();
            new File(trainOutDir + "/neg").mkdir();
            new File(testOutDir).mkdir();
            new File(testOutDir + "/pos").mkdir();
            new File(testOutDir + "/neu").mkdir();
            new File(testOutDir + "/neg").mkdir();
            HashSet<String> testsetFiles = new HashSet<String>();
            BufferedReader tsbr = new BufferedReader(new FileReader(testsetFilelist));
            String line = tsbr.readLine();
            while(line != null) {
                line = line.trim();
                testsetFiles.add(line);
                line = tsbr.readLine();
            }
            tsbr.close();
            BufferedReader br = new BufferedReader(new FileReader(fileName));
            int count = 0;
            String outDir = null;
            while(true) {
                line = br.readLine();
                if(line == null)
                    break;
                // TODO(review): header parsing reconstructed after corruption —
                // confirm tag and attribute names.
                if(line.startsWith("<file")) {
                    boolean inTestSet = false;
                    for(String tf: testsetFiles)
                        if(!tf.isEmpty() && line.contains(tf)) {
                            inTestSet = true;
                            break;
                        }
                    outDir = inTestSet? testOutDir: trainOutDir;
                    continue;
                }
                if(!line.startsWith("<sen"))
                    continue;
                count++;
                boolean hasPos = line.contains("hasPos=\"true\"");
                boolean hasNeu = line.contains("hasNeu=\"true\"");
                boolean hasNeg = line.contains("hasNeg=\"true\"");
                ArrayList<String> sen = readSentence(br);
                if(!usePolarity) {
                } else if(!multiclass) {
                } else {
                    if(hasPos)
                        printSentenceToFile(outDir + "/pos/" + count, sen);
                    if(hasNeu)
                        printSentenceToFile(outDir + "/neu/" + count, sen);
                    if(hasNeg)
                        printSentenceToFile(outDir + "/neg/" + count, sen);
                }
            }
            br.close();
        } catch(Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /** Writes the tokens space-separated on a single line to the given file. */
    private static void printSentenceToFile(String file, ArrayList<String> tokens)
            throws IOException {
        //System.out.println("file = " + file);
        PrintWriter pw = new PrintWriter(new FileWriter(file));
        for(String s: tokens)
            pw.print(s + " ");
        pw.println();
        pw.close();
    }

    /** Heuristic filter: graphs with a very high branching factor are skipped. */
    private static boolean isStrangeGraph(DepGraph dg) {
        int n = dg.getMaxBranching();
        return n > 8;
    }

    /** Prints each token's weight (one-hot score minus bias) for inspection. */
    private static void printWeightsInSentence(ArrayList<String> sen,
            SymbolEncoder enc, Classifier cl) {
        SparseVector sv = new SparseVector();
        ArrayList<Integer> range = new ArrayList<Integer>();
        range.add(SUBJ);
        range.add(NO_SUBJ);
        double[] scores = new double[2];
        cl.computeScoresRestricted(new SparseVector(), range, scores);
        double bias = scores[0];
        System.out.println("bias = " + bias);
        sv.index = 1;
        sv.values[0] = 1.0;
        for(String s: sen) {
            sv.keys[0] = enc.encode(s);
            cl.computeScoresRestricted(sv, range, scores);
            System.out.println(s + "\t" + (scores[0] - bias));
        }
    }

    /**
     * Interactive mode: loads a serialized model and classifies one
     * space-tokenized sentence per stdin line, printing "true"/"false".
     */
    public static void run(String[] argv) {
        String modelName = argv[1];
        boolean printWeights = false;
        if(argv.length > 2)
            printWeights = Boolean.parseBoolean(argv[2]);
        try {
            ObjectInputStream ois =
                new ObjectInputStream(new FileInputStream(modelName));
            SymbolEncoder enc = (SymbolEncoder) ois.readObject();
            Classifier cl = (Classifier) ois.readObject();
            SubjectivityLexicon subjLex = null;
            if(USE_SUBJLEX)
                subjLex = (SubjectivityLexicon) ois.readObject();
            ois.close(); // BUG FIX: stream was never closed
            BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
            String line = br.readLine();
            while(line != null) {
                line = line.trim();
                if(line.equals(""))
                    break;
                String[] ts = line.split(" ");
                ArrayList<String> sen = new ArrayList<String>();
                for(String t: ts)
                    sen.add(t);
                SparseVector sv = representSentence(sen, null, null, null, null,
                                                    false, subjLex, enc);
                if(SET)
                    for(int j = 0; j < sv.index; j++)
                        sv.values[j] = 1.0;
                if(NORMALIZE) {
                    normalize2(sv);
                    /*double sqrlen = SparseVector.sortedSqrLength(sv);
                    if(Math.abs(sqrlen - 1.0) > 1e-12)
                        throw new RuntimeException("!!!");*/
                }
                int guess = cl.classify(sv);
                System.out.println(guess == SUBJ? "true": "false");
                if(printWeights)
                    printWeightsInSentence(sen, enc, cl);
                line = br.readLine();
            }
        } catch(Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    public static void main(String[] argv) {
        if(argv[0].equals("-cv"))
            train_cv(argv);
        else if(argv[0].equals("-tk"))
            tkOutput(argv);
        else if(argv[0].equals("-run"))
            run(argv);
        else if(argv[0].equals("-dirs"))
            dirs(argv);
        else if(argv[0].equals("-printSentences"))
            printSentences(argv);
        else
            throw new RuntimeException("illegal parameters");
    }

    /**
     * Classifies one pre-parsed sentence with a loaded model.
     *
     * NOTE(review): unlike run(), this neither binarizes (SET) nor normalizes
     * the vector before classifying — preserved from the original, but verify
     * that this matches how the model was trained.
     */
    public static boolean classifySentence(SubjectivityLexicon sl,
            Triple<DepGraph, DepGraph, ArrayList<PAStructure>> tr,
            Classifier cl, SymbolEncoder enc) throws IOException {
        ArrayList<String> sen = new ArrayList<String>();
        for(int i = 1; i < tr.first.nodes.length; i++)
            sen.add(tr.first.nodes[i].word);
        SparseVector sv = representSentence(sen, tr, null, null, null, false, sl, enc);
        int guess = cl.classify(sv);
        return guess == SUBJ;
    }
}