ぱーせぷとろんのがくしゅう
パーセプトロンさんが学習しているさまを眺めるだけのコード。勉強用。
/**
* Copyright jetbead ( http://wonderfl.net/user/jetbead )
* MIT License ( http://www.opensource.org/licenses/mit-license.php )
* Downloaded from: http://wonderfl.net/c/uD0w
*/
package
{
import flash.display.*;
import flash.events.*;
import flash.text.TextField;
/**
 * Entry point: shows usage text and a clickable coordinate plane, then on
 * ENTER_FRAME repeatedly trains a perceptron on the clicked points and repaints
 * the plane's background by the perceptron's predicted sign.
 */
public class Main extends Sprite
{
    /** The plane spans [-coord_scale, coord_scale] on both axes. */
    private const coord_scale:Number = 10;
    /** Width and height of the plane in pixels. */
    private const coord_size:int = 200;
    private var coord:Coord2DSprite;
    private var perceptron:Perceptron;
    /** Frames skipped since the last training step. */
    private var cnt:int = 0;

    public function Main():void
    {
        if (stage) init();
        else addEventListener(Event.ADDED_TO_STAGE, init);
    }

    /** Configures the stage and builds the UI once the stage is available. */
    private function init(e:Event = null):void
    {
        removeEventListener(Event.ADDED_TO_STAGE, init);
        stage.align = StageAlign.TOP_LEFT;
        stage.scaleMode = StageScaleMode.NO_SCALE;
        stage.frameRate = 30;

        var tf:TextField = new TextField();
        tf.x = 0;
        tf.y = 0;
        tf.width = 465;
        tf.height = 50;
        tf.selectable = false;
        tf.text = "座標をクリックすると、正例(橙点)と負例(青点)を交互に入力できます。\n入力すると随時ランダムに点を選んで学習します。\n線形分離可能ならば、いつかは落ち着きます。たぶん。";
        addChild(tf);

        coord = new Coord2DSprite(coord_size, coord_size, coord_scale, coord_scale);
        // Center the plane on the point (233, 233).
        coord.x = 233 - coord_size / 2;
        coord.y = 233 - coord_size / 2;
        addChild(coord);

        perceptron = new Perceptron(0.0);
        addEventListener(Event.ENTER_FRAME, draw);
    }

    /**
     * ENTER_FRAME handler. Once at least one point exists, every third frame it
     * trains the perceptron on one randomly chosen point, highlights that point,
     * and repaints the decision regions.
     */
    private function draw(e:Event):void {
        if (coord.pointType.length == 0) return;
        if (cnt < 2) { // leave a couple of frames between training steps
            cnt++;
            return;
        }
        cnt = 0;

        // Pick a random clicked point and train on it.
        var idx:int = int(Math.random() * coord.pointType.length);
        coord.selectPoint(coord.pointX[idx], coord.pointY[idx]);
        perceptron.train(coord.pointType[idx], { "x": coord.pointX[idx], "y": coord.pointY[idx] });

        // Show the current separating line.
        draw_w();
    }

    /**
     * Paints every background pixel by the sign of the perceptron's output at
     * that pixel's plane coordinates: orange (0xffa500) where it is > 0, blue
     * (0x4169e1) otherwise.
     */
    private function draw_w():void {
        var bmd:BitmapData = coord.bg_data;
        // One feature object reused for every pixel. The bias feature uses the same
        // value Perceptron.train() fills in, so drawing matches what was trained.
        var sample:Object = { "x": 0.0, "y": 0.0, "bias": perceptron.bias };
        bmd.lock(); // defer display updates until all pixels are written
        for (var j:int = 0; j < coord_size; j++) {
            // Pixel rows grow downward while y grows upward, hence the flip.
            sample["y"] = ((coord_size - j) * 2 * coord_scale) / coord_size - coord_scale;
            for (var i:int = 0; i < coord_size; i++) {
                sample["x"] = (i * 2 * coord_scale) / coord_size - coord_scale;
                bmd.setPixel(i, j, perceptron.predict(sample) > 0 ? 0xffa500 : 0x4169e1);
            }
        }
        bmd.unlock();
    }
}
}
//座標処理用Sprite
import flash.display.Sprite;
import flash.display.Bitmap;
import flash.display.BitmapData;
import flash.events.MouseEvent;
/**
 * A clickable coordinate plane. Each click records one point in plane
 * coordinates, with a label that alternates 1, -1, 1, ... The point is drawn,
 * and the recorded points are exposed through pointType/pointX/pointY.
 */
class Coord2DSprite extends Sprite {
    /** Label given to the next clicked point; flips between 1 and -1 after each click. */
    public var type:int = 1;
    // Clicked points: label and plane coordinates, index-aligned.
    public var pointType:Array;
    public var pointX:Array;
    public var pointY:Array;
    // Parts
    public var bg:Bitmap;          // background (colored by the classifier's regions)
    public var bg_data:BitmapData;
    public var psp:Sprite;         // axes and clicked points
    public var circle:Sprite;      // marker for the point currently being trained on
    // Pixel size and plane bounds, used for coordinate conversion.
    private var W:int;
    private var H:int;
    private var minX:Number;
    private var maxX:Number;
    private var minY:Number;
    private var maxY:Number;

    /**
     * @param W_      width in pixels
     * @param H_      height in pixels
     * @param scaleX_ the plane spans [-scaleX_, scaleX_] horizontally
     * @param scaleY_ the plane spans [-scaleY_, scaleY_] vertically
     */
    public function Coord2DSprite(W_:int, H_:int, scaleX_:Number = 1, scaleY_:Number = 1) {
        W = W_;
        H = H_;
        minX = -scaleX_;
        maxX = scaleX_;
        minY = -scaleY_;
        maxY = scaleY_;
        pointType = new Array();
        pointX = new Array();
        pointY = new Array();

        // Gray background with full alpha, filled in one step instead of per pixel.
        bg_data = new BitmapData(W, H, true, 0xffcccccc);
        bg = new Bitmap(bg_data);
        this.addChild(bg);
        psp = new Sprite();
        this.addChild(psp);
        circle = new Sprite();
        this.addChild(circle);

        // Axes through the center.
        psp.graphics.lineStyle(1, 0x000000);
        psp.graphics.moveTo(0, H_ / 2);
        psp.graphics.lineTo(W_, H_ / 2);
        psp.graphics.moveTo(W_ / 2, 0);
        psp.graphics.lineTo(W_ / 2, H_);

        // Semi-transparent marker, hidden until selectPoint() is called.
        circle.graphics.lineStyle(1, 0x000000);
        circle.graphics.beginFill(0x000000, 0.5);
        circle.graphics.drawCircle(0, 0, 4);
        circle.graphics.endFill();
        circle.visible = false;

        // Listen only on this sprite. Clicks on psp/circle bubble up here, and a
        // second listener on psp would record the same click twice.
        this.addEventListener(MouseEvent.CLICK, onClick);
    }

    /** Records a point at the click position, draws it in its label's color, then flips the label. */
    private function onClick(e:MouseEvent):void {
        // Read the position in this sprite's space. e.target may be psp or circle,
        // and circle's local space is offset to the currently selected point.
        var px:Number = this.mouseX;
        var py:Number = this.mouseY;
        pointType.push(type);
        pointX.push((px * (maxX - minX)) / W + minX);
        pointY.push(((H - py) * (maxY - minY)) / H + minY); // screen y grows downward
        psp.graphics.beginFill(type == 1 ? 0xffa500 : 0x4169e1);
        psp.graphics.drawCircle(px, py, 3); // psp sits at (0, 0), so local == psp coordinates
        psp.graphics.endFill();
        type = (type == 1) ? -1 : 1;
    }

    /**
     * Shows the marker at plane coordinates (x, y), converted to whole pixels
     * (screen y grows downward).
     */
    public function selectPoint(x:Number, y:Number):void {
        var nx:int = ((x - minX) / (maxX - minX)) * W;
        var ny:int = H - ((y - minY) / (maxY - minY)) * H;
        circle.visible = true;
        circle.x = nx;
        circle.y = ny;
    }
}
// Perceptron classifier over features stored in a dynamic Object (feature name -> value).
class Perceptron {
    /** Weights keyed by feature name; features with no weight contribute nothing to the score. */
    private var w:Object;
    /** Value of the implicit "bias" feature, used when a sample has no numeric "bias" entry. */
    public var bias:Number;
    /** A sample triggers an update when label * score <= margin. */
    private var margin:Number;

    /**
     * @param margin_ update threshold on label * score (see train)
     * @param bias_   value of the implicit bias feature; also the initial bias weight
     */
    public function Perceptron(margin_:Number = 0.0, bias_:Number = 1.0) {
        bias = bias_;
        margin = margin_;
        w = new Object();
        w["bias"] = bias;
    }

    /** The sample's bias feature value: its own "bias" entry if numeric, otherwise `bias`. */
    private function biasOf(x:Object):Number {
        return isNaN(x["bias"]) ? bias : x["bias"];
    }

    /** Adds delta to weight `key`, creating the weight if it does not exist yet. */
    private function accumulate(key:String, delta:Number):void {
        w[key] = isNaN(w[key]) ? delta : w[key] + delta;
    }

    /**
     * Prediction: returns sum over features k of x[k] * w[k], skipping features
     * with no weight. The bias feature is always included, using biasOf(x), so
     * predict() scores samples exactly the way train() does. x is not modified.
     * @param x feature name -> value
     * @return the raw score; its sign is the predicted label
     */
    public function predict(x:Object):Number {
        var ret:Number = 0;
        for (var key:String in x) {
            if (key != "bias" && !isNaN(w[key])) {
                ret += x[key] * w[key];
            }
        }
        if (!isNaN(w["bias"])) {
            ret += biasOf(x) * w["bias"];
        }
        return ret;
    }

    /**
     * Batch training: makes `loop` passes over the samples, calling train(t[i], x[i]).
     * Labels and samples are paired by index. Extra entries in the longer array
     * are ignored.
     */
    public function train_all(t:Array, x:Array, loop:int):void {
        var n:int = Math.min(t.length, x.length);
        for (var l:int = 0; l < loop; l++) {
            for (var i:int = 0; i < n; i++) {
                train(t[i], x[i]);
            }
        }
    }

    /**
     * One online update step (perceptron rule). If t * predict(x) <= margin, adds
     * t * x[k] to every feature weight, including the bias feature. The caller's
     * x is left unchanged.
     * @param t label, expected to be 1 or -1
     * @param x feature name -> value
     */
    public function train(t:int, x:Object):void {
        var f:Number = predict(x);
        if (t * f <= margin) {
            for (var key:String in x) {
                if (key != "bias") {
                    accumulate(key, t * x[key]);
                }
            }
            accumulate("bias", t * biasOf(x));
        }
    }
}