package mpqareader; import java.util.*; import se.lth.cs.nlp.nlputils.core.*; import se.lth.cs.nlp.nlputils.ml_long.*; public class IntensitiesSplitter extends LearningAlgorithm { private String algName; //private String[] algArgs; private String algArgs; public IntensitiesSplitter(String algName, String algArgs) { this.algName = algName; //this.algArgs = algArgs.split("\\s+"); this.algArgs = algArgs; System.out.println(this.algName); System.out.println(Arrays.asList(this.algArgs)); } private static final int[] RANGE = { BOWPolSubjSentClassifier.NONE, BOWPolSubjSentClassifier.LOW, BOWPolSubjSentClassifier.MEDIUM, BOWPolSubjSentClassifier.HIGH }; public Classifier train( IntensitiesDefinition problem, List> trainingSet) { AlgorithmFactory af = new AlgorithmFactory(); Classifier[] cls = new Classifier[3]; for(int i = 0; i < 3; i++) { EncodedMulticlassDefinition subSpec = new EncodedMulticlassDefinition(RANGE); LearningAlgorithm alg = af.create(algName, algArgs); ArrayList> ts = new ArrayList(); for(Pair p: trainingSet) { int subInt; switch(i) { case 0: subInt = p.right.pos; break; case 1: subInt = p.right.neu; break; case 2: subInt = p.right.neg; break; default: throw new RuntimeException("!!!"); } ts.add(new Pair(p.left.copy(), subInt)); } cls[i] = alg.train(subSpec, ts); } return new SplitIntensitiesClassifier(cls); } private static class SplitIntensitiesClassifier extends Classifier { private static final long serialVersionUID = 0L; Classifier poscl, neucl, negcl; SplitIntensitiesClassifier(Classifier[] cls) { poscl = cls[0]; neucl = cls[1]; negcl = cls[2]; } public Intensities classify(SparseVector sv) { int pos = poscl.classify(sv); int neu = neucl.classify(sv); int neg = negcl.classify(sv); return new Intensities(pos, neu, neg); } } }