Thanks, Bob — storing the classifier in a blob field works great. I am
having a little trouble reading the object back out of the database, though; I
think I just don't have the right type of object.
Also, if I understand correctly, I am storing a compiled version of the
classifier. If that is true, is my code wrong, since it appears to retrieve the
classifier object and then compile it again?
Here is the full code:
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ClassifierEvaluator;
import com.aliasi.classify.ConfusionMatrix;
import com.aliasi.classify.DynamicLMClassifier;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.LMClassifier;
import com.aliasi.util.AbstractExternalizable;
import java.io.File;
import java.io.IOException;
import com.aliasi.util.Files;
import java.sql.*;
import static java.lang.System.out;
import java.io.Serializable;
import java.io.*;
public class ClassifyNews implements Serializable{
private static File TRAINING_DIR
= new File("../../data/fourNewsGroups/4news-train");
private static File TESTING_DIR
= new File("../../data/fourNewsGroups/4news-test");
private static String[] CATEGORIES
= { "soc.religion.christian",
"talk.religion.misc",
"alt.atheism",
"misc.forsale" };
private static int NGRAM_SIZE = 6;
public static void main(String[] args)
throws ClassNotFoundException, IOException {
Connection con=null;
DynamicLMClassifier classifier
= DynamicLMClassifier.createNGramProcess(CATEGORIES,NGRAM_SIZE);
/* here we will connect to the database to see if we already have a trained
classifier */
try
{
Class.forName("in.co.daffodil.db.rmi.RmiDaffodilDBDriver");
}
catch(java.lang.ClassNotFoundException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error loading jdbc driver");
}
try
{
con = DriverManager.getConnection
("jdbc:daffodilDB://127.0.0.1:3456/onyun","daffodil","daffodil");
}
catch(SQLException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error opening connection to OnYun database");
}
int NeedTraining=0;
try
{
String strSQL="select * from Collection";
Statement st = null;
ResultSet recordset=null;
st =
con.createStatement(ResultSet.TYPE_SCROLL_INSENSITIVE,ResultSet.CONCUR_READ_ONLY\
);
//open the recordset
recordset = st.executeQuery(strSQL);
recordset.first();
Blob collection = recordset.getBlob("Collection");
byte[] classifierBytes = collection.getBytes(1,(int) collection.length());
ByteArrayInputStream bytesIn= new ByteArrayInputStream(classifierBytes);
ObjectInputStream objIn = new ObjectInputStream(bytesIn);
classifier = (Classifier<CharSequence,JointClassification>)
objIn.readObject();
objIn.close();
out.println("Found Trained Classifier - Skipping Training");
recordset.close();
}
catch(SQLException e)
{
NeedTraining=1;
}
if(NeedTraining==1)
{
for(int i=0; i<CATEGORIES.length; ++i) {
File classDir = new File(TRAINING_DIR,CATEGORIES[i]);
if (!classDir.isDirectory()) {
String msg = "Could not find training directory="
+ classDir
+ "\nHave you unpacked 4 newsgroups?";
System.out.println(msg); // in case exception gets lost in shell
throw new IllegalArgumentException(msg);
}
String[] trainingFiles = classDir.list();
for (int j = 0; j < trainingFiles.length; ++j) {
File file = new File(classDir,trainingFiles[j]);
String text = Files.readFromFile(file,"ISO-8859-1");
System.out.println("Training on " + CATEGORIES[i] + "/" +
trainingFiles[j]);
classifier.train(CATEGORIES[i],text);
}
}
//compiling
// we created object so know it's safe
try
{
Class.forName("in.co.daffodil.db.rmi.RmiDaffodilDBDriver");
}
catch(java.lang.ClassNotFoundException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error loading jdbc driver");
}
try
{
con = DriverManager.getConnection
("jdbc:daffodilDB://127.0.0.1:3456/onyun","daffodil","daffodil");
}
catch(SQLException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error opening connection to OnYun database");
}
try
{
ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
ObjectOutputStream out = new ObjectOutputStream(bytesOut);
//out.writeObject(classifier);
classifier.compileTo(out);
out.close();
byte[] bytes = bytesOut.toByteArray();
PreparedStatement pstmt = con.prepareStatement("insert into Collection
(Collection) values(?)");
pstmt.setObject(1, bytes);
pstmt.executeUpdate();
}
catch(SQLException e)
{
//message = e.getMessage();
//RaiseErr=1;
out.println("Error running prepared statement");
return;
}
} // end of if need training
System.out.println("Compiling");
@SuppressWarnings("unchecked")
Classifier<CharSequence,JointClassification> compiledClassifier =
(Classifier<CharSequence,JointClassification>)
AbstractExternalizable.compile(classifier);
ClassifierEvaluator<CharSequence,JointClassification> evaluator= new
ClassifierEvaluator<CharSequence,JointClassification>(compiledClassifier,CATEGOR\
IES);
for(int i = 0; i < CATEGORIES.length; ++i) {
File classDir = new File(TESTING_DIR,CATEGORIES[i]);
String[] testingFiles = classDir.list();
for (int j=0; j<testingFiles.length; ++j) {
String text
= Files
.readFromFile(new
File(classDir,testingFiles[j]),"ISO-8859-1");
System.out.print("Testing on " + CATEGORIES[i] + "/" +
testingFiles[j] + " ");
evaluator.addCase(CATEGORIES[i],text);
JointClassification jc =
compiledClassifier.classify(text);
String bestCategory = jc.bestCategory();
String details = jc.toString();
System.out.println("Got best category of: " + bestCategory);
System.out.println(jc.toString());
System.out.println("---------------");
}
}
ConfusionMatrix confMatrix = evaluator.confusionMatrix();
System.out.println("Total Accuracy: " + confMatrix.totalAccuracy());
}
}