Skip to content

Commit

Permalink
Fix examples: double-float -> single-float
Browse files Browse the repository at this point in the history
  • Loading branch information
masatoi committed Jun 20, 2021
1 parent 2d180c8 commit 5713d93
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 59 deletions.
6 changes: 3 additions & 3 deletions example/classification/cifar10.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
(defparameter n-class 10)

(defparameter x
(make-array '(50000 3072) :element-type 'double-float))
(make-array '(50000 3072) :element-type 'single-float))

(defparameter y
(make-array 50000 :element-type 'fixnum))

(defparameter x.t
(make-array '(10000 3072) :element-type 'double-float))
(make-array '(10000 3072) :element-type 'single-float))

(defparameter y.t
(make-array 10000 :element-type 'fixnum))
Expand All @@ -30,7 +30,7 @@
(loop for i from (* n 10000) below (* (1+ n) 10000) do
(setf (aref target i) (read-byte s))
(loop for j from 0 below 3072 do
(setf (aref datamatrix i j) (coerce (read-byte s) 'double-float))))
(setf (aref datamatrix i j) (coerce (read-byte s) 'single-float))))
'done))

(loop for i from 0 to 4 do
Expand Down
73 changes: 37 additions & 36 deletions example/classification/kmnist.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

;; KMNIST data
;; https://github.com/rois-codh/kmnist
(defparameter dir (asdf:system-relative-pathname :cl-random-forest "dataset/"))
(defparameter dir (asdf:system-relative-pathname :cl-random-forest "dataset/kmnist/"))
(defparameter mnist-dim 784)
(defparameter mnist-n-class 10)

Expand Down Expand Up @@ -63,14 +63,14 @@
(number-of-rows (slot-value mnist-dataset 'number-of-rows))
(number-of-columns (slot-value mnist-dataset 'number-of-columns))
(datamatrix (make-array (list number-of-images (* number-of-rows number-of-columns))
:element-type 'double-float
:initial-element 0d0)))
:element-type 'single-float
:initial-element 0.0)))
(loop for i from 0 below number-of-images
do (let ((row (lisp-binary:read-bytes (* number-of-rows number-of-columns)
in :element-type '(unsigned-byte 8))))
(loop for j from 0 below (* number-of-rows number-of-columns)
for r across row
do (setf (aref datamatrix i j) (/ r (if scaling? 255d0 1d0))))))
do (setf (aref datamatrix i j) (/ r (if scaling? 255.0 1.0))))))
datamatrix)))

(defun read-mnist-labels (file)
Expand All @@ -84,11 +84,6 @@
do (setf (aref target i) x))
target)))

(defun read-mnist-labels (file)
(lisp-binary:with-open-binary-file (in file :direction :input)
(let* ((mnist-labels (lisp-binary:read-binary 'mnist-labels in)))
(slot-value mnist-labels 'target))))

(defparameter mnist-datamatrix
(read-mnist-dataset (merge-pathnames "train-images-idx3-ubyte" dir) :scaling? t))

Expand Down Expand Up @@ -117,65 +112,71 @@
:max-depth 15 :n-trial 28 :min-region-samples 5))

;; Prediction
(predict-dtree mnist-dtree mnist-datamatrix 0) ; => 5 (correct)
(predict-dtree mnist-dtree mnist-datamatrix 0) ; => 8 (correct)

;; Testing with training data
(test-dtree mnist-dtree mnist-datamatrix mnist-target)

;; Accuracy: 90.37333%, Correct: 54224, Total: 60000
;; Accuracy: 82.450005%, Correct: 49470, Total: 60000

;; Testing with test data
(test-dtree mnist-dtree mnist-datamatrix-test mnist-target-test)
;; Accuracy: 81.52%, Correct: 8152, Total: 10000
;; Accuracy: 56.97%, Correct: 5697, Total: 10000

;;; Make Random Forest ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;;; Enable/Disable parallelization
(setf lparallel:*kernel* (lparallel:make-kernel 4))
(setf lparallel:*kernel* nil)

;; 6.079 seconds (1 core), 2.116 seconds (4 core)
(defparameter mnist-forest
(make-forest mnist-n-class mnist-datamatrix mnist-target
:n-tree 500 :bagging-ratio 0.1 :max-depth 10 :n-trial 10 :min-region-samples 5))
;; 2.987 seconds (4 core)
(time
(defparameter mnist-forest
(make-forest mnist-n-class mnist-datamatrix mnist-target
:n-tree 500 :bagging-ratio 0.1 :max-depth 10 :n-trial 10 :min-region-samples 5)))

;; Prediction
(predict-forest mnist-forest mnist-datamatrix 0) ; => 5 (correct)
(predict-forest mnist-forest mnist-datamatrix 0) ; => 8 (correct)

;; Testing with test data
;; 4.786 seconds, Accuracy: 93.38%

;; Accuracy: 69.4%, Correct: 6940, Total: 10000 (4.775 seconds)
(test-forest mnist-forest mnist-datamatrix-test mnist-target-test)

;; 42.717 seconds (1 core), 13.24 seconds (4 core)
(defparameter mnist-forest-tall
(make-forest mnist-n-class mnist-datamatrix mnist-target
:n-tree 100 :bagging-ratio 1.0 :max-depth 15 :n-trial 28 :min-region-samples 5))
;; 16.847 seconds (4 core)
(time
(defparameter mnist-forest-tall
(make-forest mnist-n-class mnist-datamatrix mnist-target
:n-tree 100 :bagging-ratio 1.0 :max-depth 15 :n-trial 28 :min-region-samples 5)))

;; 2.023 seconds, Accuracy: 96.62%
(test-forest mnist-forest-tall mnist-datamatrix-test mnist-target-test)
;; 1.291 seconds, Accuracy: 81.23%
(time (test-forest mnist-forest-tall mnist-datamatrix-test mnist-target-test))

;;; Global Refinement of Random Forest ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; Generate sparse data from Random Forest

;; 6.255 seconds (1 core), 1.809 seconds (4 core)
(defparameter mnist-refine-dataset
(make-refine-dataset mnist-forest mnist-datamatrix))
;; 3.303 seconds (4 core)
(time
(defparameter mnist-refine-dataset
(make-refine-dataset mnist-forest mnist-datamatrix)))

;; 0.995 seconds (1 core), 0.322 seconds (4 core)
(defparameter mnist-refine-test
(make-refine-dataset mnist-forest mnist-datamatrix-test))
;; 0.423 seconds (4 core)
(time
(defparameter mnist-refine-test
(make-refine-dataset mnist-forest mnist-datamatrix-test)))

(defparameter mnist-refine-learner (make-refine-learner mnist-forest))

;; 4.347 seconds (1 core), 2.281 seconds (4 core), Accuracy: 98.259%
(train-refine-learner-process mnist-refine-learner mnist-refine-dataset mnist-target
mnist-refine-test mnist-target-test)
;; 2.495 seconds (4 core), Accuracy: 90.97%
(time
(train-refine-learner-process mnist-refine-learner mnist-refine-dataset mnist-target
mnist-refine-test mnist-target-test))

(test-refine-learner mnist-refine-learner mnist-refine-test mnist-target-test)

;; 5.859 seconds (1 core), 4.090 seconds (4 core), Accuracy: 98.29%
(loop repeat 5 do
;; more training
(loop repeat 10 do
(train-refine-learner mnist-refine-learner mnist-refine-dataset mnist-target)
(test-refine-learner mnist-refine-learner mnist-refine-test mnist-target-test))

Expand Down Expand Up @@ -221,4 +222,4 @@

(cross-validation-forest-with-refine-learner
n-fold mnist-n-class mnist-datamatrix mnist-target
:n-tree 100 :bagging-ratio 0.1 :max-depth 10 :n-trial 28 :gamma 10d0 :min-region-samples 5)
:n-tree 100 :bagging-ratio 0.1 :max-depth 10 :n-trial 28 :gamma 10.0 :min-region-samples 5)
40 changes: 20 additions & 20 deletions src/feature-importance.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
&key quiet-p oob-sample-indices)
(declare (optimize (speed 3) (safety 0))
(type dtree dtree)
(type (simple-array double-float) datamatrix)
(type (simple-array single-float) datamatrix)
(type (simple-array fixnum (*)) target))
(let* ((n-correct 0)
(oob-sample-indices (if (null oob-sample-indices)
Expand All @@ -43,7 +43,7 @@
(defun find-leaf-randomized (node datamatrix datum-index randomized-attribute oob-sample-indices)
(declare (optimize (speed 3) (safety 0))
(type fixnum datum-index)
(type (simple-array double-float) datamatrix)
(type (simple-array single-float) datamatrix)
(type (simple-array fixnum) oob-sample-indices))
(flet ((random-pick-oob-index ()
(aref oob-sample-indices (random (length oob-sample-indices)))))
Expand All @@ -55,7 +55,7 @@
(aref datamatrix (random-pick-oob-index) attribute)
(aref datamatrix datum-index attribute))))
(declare (type fixnum attribute)
(type double-float threshold datum))
(type single-float threshold datum))
(if (>= datum threshold)
(find-leaf-randomized (node-left-node node) datamatrix datum-index
randomized-attribute oob-sample-indices)
Expand All @@ -65,18 +65,18 @@
(defun predict-dtree-randomized (dtree datamatrix datum-index randomized-attribute oob-sample-indices)
(declare (optimize (speed 3) (safety 0))
(type dtree dtree)
(type (simple-array double-float) datamatrix)
(type (simple-array single-float) datamatrix)
(type fixnum datum-index randomized-attribute)
(type (simple-array fixnum) oob-sample-indices))
(let ((max 0d0)
(let ((max 0.0)
(max-class 0)
(dist (node-class-distribution
(find-leaf-randomized (dtree-root dtree) datamatrix datum-index
randomized-attribute oob-sample-indices)))
(n-class (dtree-n-class dtree)))
(declare (type double-float max)
(declare (type single-float max)
(type fixnum max-class n-class)
(type (simple-array double-float) dist))
(type (simple-array single-float) dist))
(loop for i fixnum from 0 to (1- n-class) do
(when (> (aref dist i) max)
(setf max (aref dist i)
Expand All @@ -87,7 +87,7 @@
&key quiet-p oob-sample-indices)
(declare (optimize (speed 3) (safety 0))
(type dtree dtree)
(type (simple-array double-float) datamatrix)
(type (simple-array single-float) datamatrix)
(type (simple-array fixnum (*)) target))
(let* ((n-correct 0)
(oob-sample-indices (if (null oob-sample-indices)
Expand Down Expand Up @@ -126,13 +126,13 @@
(defun test-rtree-oob (rtree datamatrix target &key quiet-p oob-sample-indices)
(declare (optimize (speed 3) (safety 0))
(type dtree rtree)
(type (simple-array double-float) datamatrix target))
(let* ((sum-square-error 0d0)
(type (simple-array single-float) datamatrix target))
(let* ((sum-square-error 0.0)
(oob-sample-indices (if (null oob-sample-indices)
(dtree-oob-sample-indices rtree)
oob-sample-indices))
(len-oob (length oob-sample-indices)))
(declare (type double-float sum-square-error)
(declare (type single-float sum-square-error)
(type fixnum len-oob)
(type (simple-array fixnum) oob-sample-indices))
(loop for i fixnum from 0 below len-oob do
Expand All @@ -157,14 +157,14 @@
&key quiet-p oob-sample-indices)
(declare (optimize (speed 3) (safety 0))
(type dtree rtree)
(type (simple-array double-float) datamatrix target)
(type (simple-array single-float) datamatrix target)
(type fixnum randomized-attribute))
(let* ((sum-square-error 0d0)
(let* ((sum-square-error 0.0)
(oob-sample-indices (if (null oob-sample-indices)
(dtree-oob-sample-indices rtree)
oob-sample-indices))
(len-oob (length oob-sample-indices)))
(declare (type double-float sum-square-error)
(declare (type single-float sum-square-error)
(type fixnum len-oob)
(type (simple-array fixnum) oob-sample-indices))
(loop for i fixnum from 0 below len-oob do
Expand All @@ -181,7 +181,7 @@
(let* ((oob-sample-indices (dtree-oob-sample-indices rtree))
(rms-oob (test-rtree-oob rtree datamatrix target
:quiet-p t :oob-sample-indices oob-sample-indices))
(result (make-array (dtree-datum-dim rtree) :element-type 'double-float :initial-element 0d0)))
(result (make-array (dtree-datum-dim rtree) :element-type 'single-float :initial-element 0.0)))
(loop for i from 0 below (dtree-datum-dim rtree) do
(setf (aref result i)
(- (test-rtree-oob-randomized rtree datamatrix target i :quiet-p t)
Expand All @@ -208,8 +208,8 @@

(defun dtree-feature-importance-impurity (dtree)
(let* ((dim (dtree-datum-dim dtree))
(acc-arr (clol::make-dvec dim 0d0))
(cnt-arr (clol::make-dvec dim 0d0)))
(acc-arr (clol::make-vec dim 0.0))
(cnt-arr (clol::make-vec dim 0.0)))

;; ignore root and leaf nodes
(flet ((store-decrease-impurity (node)
Expand All @@ -222,12 +222,12 @@
(- (node-information-gain node)
(+ (* (/ (node-n-sample left) len) (node-information-gain left))
(* (/ (node-n-sample right) len) (node-information-gain right)))))
(incf (aref cnt-arr attr) 1d0)))))
(incf (aref cnt-arr attr) 1.0)))))
(traverse #'store-decrease-impurity (node-left-node (dtree-root dtree)))
(traverse #'store-decrease-impurity (node-right-node (dtree-root dtree))))

(loop for i from 0 below dim do
(when (> (aref cnt-arr i) 0d0)
(when (> (aref cnt-arr i) 0.0)
(setf (aref acc-arr i) (/ (aref acc-arr i) (aref cnt-arr i)))))

(let ((min (loop for i from 0 below dim minimize (aref acc-arr i))))
Expand All @@ -237,7 +237,7 @@

(defun forest-feature-importance-impurity (forest)
(let* ((len (forest-datum-dim forest))
(result (make-array len :initial-element 0d0)))
(result (make-array len :initial-element 0.0)))
(dolist (importance-vec
(mapcar/pmapcar #'dtree-feature-importance-impurity (forest-dtree-list forest)))
(loop for i from 0 below len do
Expand Down

0 comments on commit 5713d93

Please sign in to comment.