Naive Bayes Classifier
Document Classification using Naive Bayes.
ナイーブベイズによる文書分類
see also:
http://en.wikipedia.org/wiki/Naive_Bayes_classifier
http://rest-term.com/archives/2925/
/**
* Copyright wellflat ( http://wonderfl.net/user/wellflat )
* MIT License ( http://www.opensource.org/licenses/mit-license.php )
* Downloaded from: http://wonderfl.net/c/vT8x
*/
package {
import flash.display.Graphics;
import flash.display.Sprite;
import flash.text.TextField;
import flash.text.TextFormat;
import flash.utils.Dictionary;
[SWF(width="465", height="465", backgroundColor="#000000")]
public class Main extends Sprite {
    /** Classifier instance driven by the demo. */
    private var nb:NaiveBayes;
    /** Lazily created text field that serves as the on-screen console. */
    private var tf:TextField;

    /**
     * Runs the demo: trains the classifier on the sample set, prints the
     * scores and resulting class for a few queries, then raises the
     * threshold for "bad", trains ten more rounds and prints the last
     * query again.
     */
    public function Main():void {
        p("*** Document Classification using Naive Bayes ***\n");
        nb = new NaiveBayes(getFeatures);
        sampleTrain(nb); // initial training pass

        // Human-readable listing of the training set (followed by a blank line).
        var listing:Array = [
            " - training data - ",
            "Hello how are you? -> good",
            "make quick money to live a good life -> bad",
            "the quick brown fox jumps -> good",
            "Dive into HTML5 -> good",
            "HTML5 vs. Flash comparison -> bad",
            ""
        ];
        for (var row:int = 0; row < listing.length; row++) {
            p(listing[row]);
        }

        p("P(quick fox|good) = " + nb.getProb("quick fox", "good").toString());
        p("P(quick fox|bad) = " + nb.getProb("quick fox", "bad").toString());
        p("class 'quick fox': " + nb.classify("quick fox"));
        p("");

        p("P(quick money|good) = " + nb.getProb("quick money", "good").toString());
        p("P(quick money|bad) = " + nb.getProb("quick money", "bad").toString());
        p("class 'quick money': " + nb.classify("quick money"));
        p("");

        // From here on "bad" must beat every other category by a factor of 3.
        nb.setThreshold("bad", 3.0);
        p("P(HTML5 Flash|good) = " + nb.getProb("HTML5 Flash", "good").toString());
        p("P(HTML5 Flash|bad) = " + nb.getProb("HTML5 Flash", "bad").toString());
        p("class 'HTML5 Flash': " + nb.classify("HTML5 Flash"));
        p("");

        // Ten additional training rounds over the same sample set.
        var roundsLeft:int = 10;
        while (roundsLeft > 0) {
            sampleTrain(nb);
            roundsLeft--;
        }
        p("P(HTML5 Flash|good) = " + nb.getProb("HTML5 Flash", "good").toString());
        p("P(HTML5 Flash|bad) = " + nb.getProb("HTML5 Flash", "bad").toString());
        p("class 'HTML5 Flash': " + nb.classify("HTML5 Flash"));
    }

    /**
     * Feeds the five labelled sample documents to the given classifier.
     *
     * @param classifier classifier to train
     */
    private function sampleTrain(classifier:Classifier):void {
        classifier.train("Hello how are you?", "good");
        classifier.train("make quick money to live a good life", "bad");
        classifier.train("the quick brown fox jumps", "good");
        classifier.train("Dive into HTML5", "good");
        classifier.train("HTML5 vs. Flash comparison", "bad");
    }

    /**
     * Builds a bag-of-words feature set from a document.
     * The text is split on runs of non-word characters; tokens of one or
     * two characters are dropped and the rest are lower-cased.
     *
     * @param document text to tokenize
     * @return Dictionary whose keys are the distinct words (each value is 1)
     */
    private function getFeatures(document:String):Dictionary {
        var bag:Dictionary = new Dictionary();
        var tokens:Array = document.split(/\W+/);
        for (var k:int = 0; k < tokens.length; k++) {
            var token:String = tokens[k];
            if (token.length <= 2) {
                continue; // too short to be a useful feature
            }
            bag[token.toLowerCase()] = 1;
        }
        return bag;
    }

    /**
     * Appends one line of text to the on-screen console, creating the
     * console on the first call.
     *
     * @param str line to print (a newline is appended)
     */
    private function p(str:String):void {
        if (tf == null) {
            initConsole();
        }
        tf.appendText(str + "\n");
    }

    /** Paints the black background and creates the console text field. */
    private function initConsole():void {
        var canvas:Graphics = this.graphics;
        canvas.beginFill(0x000000);
        canvas.drawRect(0, 0, stage.stageWidth, stage.stageHeight);
        canvas.endFill();

        tf = new TextField();
        tf.x = 10;
        tf.y = 20;
        tf.height = 465;
        tf.width = 465;
        tf.defaultTextFormat = new TextFormat('Courier New', 12, 0x00ff66, true);
        addChild(tf);
    }
}
}
import flash.utils.Dictionary;
// Base Classifier
/**
 * Base classifier that keeps the training statistics:
 * how often each feature was seen per category, and how many documents
 * were trained per category. Subclasses turn these counts into scores.
 */
class Classifier {
    /** feature -> (category -> number of training documents containing it). */
    private var featureCount:Dictionary;
    /** category -> number of training documents. */
    private var categoryCount:Dictionary;
    /** Callback that maps a document string to a Dictionary keyed by feature. */
    protected var getFeatures:Function;

    /**
     * @param getFeatures function(document:String):Dictionary returning the
     *                    features of a document as Dictionary keys
     */
    public function Classifier(getFeatures:Function) {
        this.featureCount = new Dictionary();
        this.categoryCount = new Dictionary();
        this.getFeatures = getFeatures;
    }

    /**
     * Records one training document for a category: every extracted
     * feature is counted once, then the category's document count is bumped.
     *
     * @param document training text
     * @param category label of the document
     */
    public function train(document:String, category:String):void {
        var features:Dictionary = getFeatures(document);
        for (var f:String in features) {
            incrFeatureCount(f, category);
        }
        incrCategoryCount(category);
    }

    /**
     * @param category category label
     * @return number of documents trained for the category (0 if never seen)
     */
    protected function getCategoryCount(category:String):Number {
        if (categoryCount[category]) {
            return categoryCount[category];
        }
        return 0.0;
    }

    /**
     * @param feature  feature (word)
     * @param category category label
     * @return how many training documents of the category contained the
     *         feature; 0 when the feature or the pair was never seen
     */
    protected function getFeatureCount(feature:String, category:String):Number {
        // A feature that never appeared in training has no inner table at
        // all; indexing it directly would throw, so check the outer lookup first.
        var perCategory:Dictionary = featureCount[feature] as Dictionary;
        if (perCategory == null || !perCategory[category]) {
            return 0.0;
        }
        return Number(perCategory[category]);
    }

    /** @return total number of training documents over all categories */
    protected function getTotalCount():uint {
        var cnt:uint = 0;
        for each (var n:uint in categoryCount) {
            cnt += n;
        }
        return cnt;
    }

    /** @return array of all category labels seen during training */
    protected function getCategories():Array {
        var categories:Array = [];
        for (var c:String in categoryCount) {
            categories.push(c);
        }
        return categories;
    }

    /**
     * P(feature|category): fraction of the category's documents that
     * contained the feature.
     *
     * @return the ratio, or 0 when the category has no documents
     */
    protected function getFeatureProb(feature:String, category:String):Number {
        var docs:Number = getCategoryCount(category);
        if (docs == 0) {
            return 0.0;
        }
        return getFeatureCount(feature, category) / docs;
    }

    /**
     * P(feature|category) smoothed towards an assumed probability:
     * (weight*aprob + totals*basicProb) / (weight + totals), where totals
     * is the feature's count summed over all categories. With no
     * observations the result is aprob.
     *
     * @param feature  feature (word)
     * @param category category label
     * @param weight   weight given to the assumed probability
     * @param aprob    assumed probability
     */
    protected function getWeightedFeatureProb(feature:String, category:String,
                                              weight:Number = 1.0, aprob:Number = 0.5):Number {
        var basicProb:Number = getFeatureProb(feature, category);
        var totals:Number = 0.0;
        var categories:Array = getCategories();
        for (var i:int = 0; i < categories.length; i++) {
            totals += getFeatureCount(feature, categories[i]);
        }
        return ((weight * aprob) + (totals * basicProb)) / (weight + totals);
    }

    /** Adds one to the (feature, category) counter, creating it if needed. */
    private function incrFeatureCount(feature:String, category:String):void {
        if (!featureCount[feature]) {
            featureCount[feature] = new Dictionary();
        }
        if (!featureCount[feature][category]) {
            featureCount[feature][category] = 0;
        }
        featureCount[feature][category]++;
    }

    /** Adds one to the category's document counter, creating it if needed. */
    private function incrCategoryCount(category:String):void {
        if (!categoryCount[category]) {
            categoryCount[category] = 0;
        }
        categoryCount[category]++;
    }
}
// Naive Bayes Classifier
/**
 * Naive Bayes document classifier built on the counts kept by Classifier.
 * Scores are proportional to P(category|document); a per-category threshold
 * controls how decisively a category must win before it is returned.
 */
class NaiveBayes extends Classifier {
    /** category -> factor by which it must beat every other category. */
    private var thresholds:Dictionary;

    /**
     * @param getFeatures function(document:String):Dictionary returning the
     *                    features of a document as Dictionary keys
     */
    public function NaiveBayes(getFeatures:Function) {
        super(getFeatures);
        thresholds = new Dictionary();
    }

    /**
     * @param category category label
     * @return threshold set for the category, or 1.0 when none was set
     */
    public function getThreshold(category:String):Number {
        // Direct key lookup instead of scanning every key.
        if (category != null && (category in thresholds)) {
            return thresholds[category];
        }
        return 1.0;
    }

    /**
     * Sets the factor by which the category's score must exceed every
     * other category's score for classify() to return it.
     */
    public function setThreshold(category:String, threshold:Number):void {
        thresholds[category] = threshold;
    }

    /**
     * Classifies a document into the best-scoring category.
     *
     * @param document        text to classify
     * @param defaultCategory label returned when no category can be chosen:
     *                        nothing scored above zero, or some other
     *                        category's score times the winner's threshold
     *                        exceeds the winner's score
     * @return the winning category label, or defaultCategory
     */
    public function classify(document:String, defaultCategory:String = "unknown"):String {
        var probs:Dictionary = new Dictionary();
        var max:Number = 0.0;
        var best:String = null;
        var categories:Array = getCategories();
        for (var i:int = 0; i < categories.length; i++) {
            var candidate:String = categories[i];
            var score:Number = getProb(document, candidate);
            probs[candidate] = score;
            if (score > max) {
                max = score;
                best = candidate;
            }
        }
        // Untrained classifier or all-zero scores: there is no winner to
        // compare against, so fall back instead of returning null.
        if (best == null) {
            return defaultCategory;
        }
        var threshold:Number = getThreshold(best);
        for (var other:String in probs) {
            if (other == best) {
                continue;
            }
            if (probs[other] * threshold > max) {
                return defaultCategory;
            }
        }
        return best;
    }

    /**
     * Bayes' theorem:
     * P(category|document) = P(document|category)P(category)/P(document).
     * P(document) is the same for every category and is left out, so the
     * result is a score proportional to the posterior, not a probability.
     *
     * @return P(document|category) * P(category); 0 when nothing has been
     *         trained yet
     */
    public function getProb(document:String, category:String):Number {
        var total:uint = getTotalCount();
        if (total == 0) {
            return 0.0; // avoids 0/0 = NaN before any training
        }
        // P(category)
        var categoryProb:Number = getCategoryCount(category) / total;
        // P(document|category)
        var documentProb:Number = getDocumentProb(document, category);
        return documentProb * categoryProb;
    }

    /**
     * P(document|category): product of the weighted feature probabilities
     * of every feature in the document (features treated as independent).
     */
    private function getDocumentProb(document:String, category:String):Number {
        var features:Dictionary = getFeatures(document);
        var prob:Number = 1.0;
        for (var f:String in features) {
            prob *= getWeightedFeatureProb(f, category);
        }
        return prob;
    }
}