001 package org.maltparser.parser.guide.instance; 002 003 import java.io.IOException; 004 import java.lang.reflect.Constructor; 005 import java.lang.reflect.InvocationTargetException; 006 import java.util.ArrayList; 007 import java.util.Formatter; 008 009 import org.maltparser.core.exception.MaltChainedException; 010 import org.maltparser.core.feature.FeatureVector; 011 import org.maltparser.core.feature.function.FeatureFunction; 012 import org.maltparser.core.feature.function.Modifiable; 013 import org.maltparser.core.syntaxgraph.DependencyStructure; 014 import org.maltparser.ml.LearningMethod; 015 import org.maltparser.parser.guide.ClassifierGuide; 016 import org.maltparser.parser.guide.GuideException; 017 import org.maltparser.parser.guide.Model; 018 import org.maltparser.parser.history.action.SingleDecision; 019 020 021 /** 022 023 @author Johan Hall 024 @since 1.0 025 */ 026 public class AtomicModel implements InstanceModel { 027 private Model parent; 028 private String modelName; 029 private FeatureVector featureVector; 030 private int index; 031 private int frequency = 0; 032 private LearningMethod method; 033 034 035 /** 036 * Constructs an atomic model. 037 * 038 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model 039 * or the master divide model) and n is number of divide models. 040 * @param features the feature vector used by the atomic model. 041 * @param parent the parent guide model. 042 * @throws MaltChainedException 043 */ 044 public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException { 045 setParent(parent); 046 setIndex(index); 047 if (index == -1) { 048 setModelName(parent.getModelName()+"."); 049 } else { 050 setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+"."); 051 } 052 setFeatures(features); 053 setFrequency(0); 054 initMethod(); 055 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) { 056 try { 057 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString()); 058 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush(); 059 } catch (IOException e) { 060 throw new GuideException("Could not write learner settings to the information file. ", e); 061 } 062 } 063 } 064 065 public void addInstance(SingleDecision decision) throws MaltChainedException { 066 try { 067 method.addInstance(decision, featureVector); 068 } catch (NullPointerException e) { 069 throw new GuideException("The learner cannot be found. ", e); 070 } 071 } 072 073 074 public void noMoreInstances() throws MaltChainedException { 075 try { 076 method.noMoreInstances(); 077 } catch (NullPointerException e) { 078 throw new GuideException("The learner cannot be found. ", e); 079 } 080 } 081 082 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException { 083 try { 084 method.finalizeSentence(dependencyGraph); 085 } catch (NullPointerException e) { 086 throw new GuideException("The learner cannot be found. ", e); 087 } 088 } 089 090 public boolean predict(SingleDecision decision) throws MaltChainedException { 091 try { 092 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 093 throw new GuideException("Cannot predict during batch training. "); 094 } 095 return method.predict(featureVector, decision); 096 } catch (NullPointerException e) { 097 throw new GuideException("The learner cannot be found. ", e); 098 } 099 } 100 101 public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException { 102 try { 103 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 104 throw new GuideException("Cannot predict during batch training. "); 105 } 106 if (method.predict(featureVector, decision)) { 107 return featureVector; 108 } 109 return null; 110 } catch (NullPointerException e) { 111 throw new GuideException("The learner cannot be found. ", e); 112 } 113 } 114 115 public FeatureVector extract() throws MaltChainedException { 116 return featureVector; 117 } 118 119 public void terminate() throws MaltChainedException { 120 if (method != null) { 121 method.terminate(); 122 method = null; 123 } 124 featureVector = null; 125 parent = null; 126 } 127 128 /** 129 * Moves all instance from this atomic model into the destination atomic model and add the divide feature. 130 * This method is used by the feature divide model to sum up all model below a certain threshold. 131 * 132 * @param model the destination atomic model 133 * @param divideFeature the divide feature 134 * @param divideFeatureIndexVector the divide feature index vector 135 * @throws MaltChainedException 136 */ 137 public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException { 138 if (method == null) { 139 throw new GuideException("The learner cannot be found. "); 140 } else if (model == null) { 141 throw new GuideException("The guide model cannot be found. "); 142 } else if (divideFeature == null) { 143 throw new GuideException("The divide feature cannot be found. "); 144 } else if (divideFeatureIndexVector == null) { 145 throw new GuideException("The divide feature index vector cannot be found. "); 146 } 147 ((Modifiable)divideFeature).setFeatureValue(index); 148 method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector); 149 method.terminate(); 150 method = null; 151 } 152 153 /** 154 * Invokes the train() of the learning method 155 * 156 * @throws MaltChainedException 157 */ 158 public void train() throws MaltChainedException { 159 try { 160 method.train(featureVector); 161 method.terminate(); 162 method = null; 163 } catch (NullPointerException e) { 164 throw new GuideException("The learner cannot be found. ", e); 165 } 166 } 167 168 /** 169 * Initialize the learning method according to the option --learner-method. 170 * 171 * @throws MaltChainedException 172 */ 173 public void initMethod() throws MaltChainedException { 174 Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner"); 175 // if (clazz == org.maltparser.ml.libsvm.Libsvm.class && (Boolean)getGuide().getConfiguration().getOptionValue("malt0.4", "behavior") == true) { 176 // try { 177 // clazz = Class.forName("org.maltparser.ml.libsvm.malt04.LibsvmMalt04"); 178 // } catch (ClassNotFoundException e) { 179 // throw new GuideException("Could not find the class 'org.maltparser.ml.libsvm.malt04.LibsvmMalt04'. ", e); 180 // } 181 // } 182 Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class }; 183 Object[] arguments = new Object[2]; 184 arguments[0] = this; 185 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) { 186 arguments[1] = LearningMethod.CLASSIFY; 187 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) { 188 arguments[1] = LearningMethod.BATCH; 189 } 190 191 try { 192 Constructor<?> constructor = clazz.getConstructor(argTypes); 193 this.method = (LearningMethod)constructor.newInstance(arguments); 194 } catch (NoSuchMethodException e) { 195 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 196 } catch (InstantiationException e) { 197 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 198 } catch (IllegalAccessException e) { 199 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 200 } catch (InvocationTargetException e) { 201 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e); 202 } 203 } 204 205 206 207 /** 208 * Returns the parent guide model 209 * 210 * @return the parent guide model 211 */ 212 public Model getParent() throws MaltChainedException { 213 if (parent == null) { 214 throw new GuideException("The atomic model can only be used by a parent model. "); 215 } 216 return parent; 217 } 218 219 /** 220 * Sets the parent guide model 221 * 222 * @param parent the parent guide model 223 */ 224 protected void setParent(Model parent) { 225 this.parent = parent; 226 } 227 228 public String getModelName() { 229 return modelName; 230 } 231 232 /** 233 * Sets the name of the atomic model 234 * 235 * @param modelName the name of the atomic model 236 */ 237 protected void setModelName(String modelName) { 238 this.modelName = modelName; 239 } 240 241 /** 242 * Returns the feature vector used by this atomic model 243 * 244 * @return a feature vector object 245 */ 246 public FeatureVector getFeatures() { 247 return featureVector; 248 } 249 250 /** 251 * Sets the feature vector used by the atomic model. 252 * 253 * @param features a feature vector object 254 */ 255 protected void setFeatures(FeatureVector features) { 256 this.featureVector = features; 257 } 258 259 public ClassifierGuide getGuide() { 260 return parent.getGuide(); 261 } 262 263 /** 264 * Returns the index of the atomic model 265 * 266 * @return the index of the atomic model 267 */ 268 public int getIndex() { 269 return index; 270 } 271 272 /** 273 * Sets the index of the model (-1..n), where -1 is a special value. 274 * 275 * @param index index value (-1..n) of the atomic model 276 */ 277 protected void setIndex(int index) { 278 this.index = index; 279 } 280 281 /** 282 * Returns the frequency (number of instances) 283 * 284 * @return the frequency (number of instances) 285 */ 286 public int getFrequency() { 287 return frequency; 288 } 289 290 /** 291 * Increase the frequency by 1 292 */ 293 public void increaseFrequency() { 294 if (parent instanceof InstanceModel) { 295 ((InstanceModel)parent).increaseFrequency(); 296 } 297 frequency++; 298 } 299 300 public void decreaseFrequency() { 301 if (parent instanceof InstanceModel) { 302 ((InstanceModel)parent).decreaseFrequency(); 303 } 304 frequency--; 305 } 306 /** 307 * Sets the frequency (number of instances) 308 * 309 * @param frequency (number of instances) 310 */ 311 protected void setFrequency(int frequency) { 312 this.frequency = frequency; 313 } 314 315 /** 316 * Returns a learner object 317 * 318 * @return a learner object 319 */ 320 public LearningMethod getMethod() { 321 return method; 322 } 323 324 325 /* (non-Javadoc) 326 * @see java.lang.Object#toString() 327 */ 328 public String toString() { 329 final StringBuilder sb = new StringBuilder(); 330 sb.append(method.toString()); 331 return sb.toString(); 332 } 333 }