I fiddled with the code some more and got it to work.
Here is the working code.
Note: this is extremely fast, especially since the blob created from the
original training articles is a little over 5 MB.
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ClassifierEvaluator;
import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.classify.DynamicLMClassifier;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.LMClassifier;
import com.aliasi.util.AbstractExternalizable;
import java.io.File;
import java.io.IOException;
import com.aliasi.util.Files;
import java.sql.*;
import static java.lang.System.out;
import java.io.Serializable;
import java.io.*;
public class ClassifyNews {
private static File TRAINING_DIR
= new File("../../data/fourNewsGroups/4news-train");
private static File TESTING_DIR
= new File("../../data/fourNewsGroups/4news-test");
private static String[] CATEGORIES
= { "soc.religion.christian",
"talk.religion.misc",
"alt.atheism",
"misc.forsale" };
private static int NGRAM_SIZE = 6;
public static void main(String[] args)
throws ClassNotFoundException, IOException {
Connection con=null;
DynamicLMClassifier classifier
= DynamicLMClassifier.createNGramProcess(CATEGORIES,NGRAM_SIZE);
Classifier<CharSequence,JointClassification> precompiledClassifier = null;
/* here we will connect to the database to see if we already have a trained
classifier */
try
{
Class.forName("in.co.daffodil.db.rmi.RmiDaffodilDBDriver");
}
catch(java.lang.ClassNotFoundException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error loading jdbc driver");
}
try
{
con = DriverManager.getConnection
("jdbc:daffodilDB://127.0.0.1:3456/onyun","daffodil","daffodil");
}
catch(SQLException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error opening connection to OnYun database");
}
int NeedTraining=0;
try
{
String strSQL="select * from Collection";
Statement st = null;
ResultSet recordset=null;
st =
con.createStatement(ResultSet.TYPE_SCROLL_INSENSITIVE,ResultSet.CONCUR_READ_ONLY\
);
//open the recordset
recordset = st.executeQuery(strSQL);
recordset.first();
Blob collection = recordset.getBlob("Collection");
byte[] classifierBytes = collection.getBytes(1,(int) collection.length());
ByteArrayInputStream bytesIn= new ByteArrayInputStream(classifierBytes);
ObjectInputStream objIn = new ObjectInputStream(bytesIn);
//classifier = (Classifier<CharSequence,JointClassification>)
objIn.readObject();
precompiledClassifier = (Classifier<CharSequence,JointClassification>)
objIn.readObject();
//out.println("length = " + collection.length());
objIn.close();
out.println("Found Trained Classifier - Skipping Training");
recordset.close();
}
catch(SQLException e)
{
NeedTraining=1;
}
if(NeedTraining==1)
{
for(int i=0; i<CATEGORIES.length; ++i) {
File classDir = new File(TRAINING_DIR,CATEGORIES[i]);
if (!classDir.isDirectory()) {
String msg = "Could not find training directory="
+ classDir
+ "\nHave you unpacked 4 newsgroups?";
System.out.println(msg); // in case exception gets lost in shell
throw new IllegalArgumentException(msg);
}
String[] trainingFiles = classDir.list();
for (int j = 0; j < trainingFiles.length; ++j) {
File file = new File(classDir,trainingFiles[j]);
String text = Files.readFromFile(file,"ISO-8859-1");
System.out.println("Training on " + CATEGORIES[i] + "/" +
trainingFiles[j]);
classifier.train(CATEGORIES[i],text);
}
}
//compiling
// we created object so know it's safe
try
{
Class.forName("in.co.daffodil.db.rmi.RmiDaffodilDBDriver");
}
catch(java.lang.ClassNotFoundException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error loading jdbc driver");
}
try
{
con = DriverManager.getConnection
("jdbc:daffodilDB://127.0.0.1:3456/onyun","daffodil","daffodil");
}
catch(SQLException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error opening connection to OnYun database");
}
try
{
ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
ObjectOutputStream out = new ObjectOutputStream(bytesOut);
//out.writeObject(classifier);
classifier.compileTo(out);
out.close();
byte[] bytes = bytesOut.toByteArray();
PreparedStatement pstmt = con.prepareStatement("insert into Collection
(Collection) values(?)");
pstmt.setObject(1, bytes);
pstmt.executeUpdate();
}
catch(SQLException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error running prepared statement");
return;
}
} // end of if need training
System.out.println("Compiling");
@SuppressWarnings("unchecked")
Classifier<CharSequence,JointClassification> compiledClassifier =
(Classifier<CharSequence,JointClassification>)
AbstractExternalizable.compile(classifier);
if(NeedTraining==0)
compiledClassifier=precompiledClassifier;
ClassifierEvaluator<CharSequence,JointClassification> evaluator= new
ClassifierEvaluator<CharSequence,JointClassification>(compiledClassifier,CATEGOR\
IES);
for(int i = 0; i < CATEGORIES.length; ++i) {
File classDir = new File(TESTING_DIR,CATEGORIES[i]);
String[] testingFiles = classDir.list();
for (int j=0; j<testingFiles.length; ++j) {
String text
= Files
.readFromFile(new
File(classDir,testingFiles[j]),"ISO-8859-1");
System.out.print("Testing on " + CATEGORIES[i] + "/" +
testingFiles[j] + " ");
evaluator.addCase(CATEGORIES[i],text);
JointClassification jc =
compiledClassifier.classify(text);
String bestCategory = jc.bestCategory();
String details = jc.toString();
System.out.println("Got best category of: " + bestCategory);
System.out.println(jc.toString());
System.out.println("---------------");
}
}
ConfusionMatrix confMatrix = evaluator.confusionMatrix();
System.out.println("Total Accuracy: " + confMatrix.totalAccuracy());
}
}