UPDATE: 2022-12-15 20:54:40

はじめに

ここでは{FastKNN}の使い方をまとめておく。{FastKNN}を使えば、高速にk最近傍分類器が作れるほか、K近傍を用いた特徴量エンジニアリングも可能。通常の{knn}よりも50倍は速いそうで、Kaggleなどではこの手法が特徴量エンジニアリングの手段としてよく使われている。以下で、インストール方法と使い方を順に見ていく。

パッケージインストール

# Install {fastknn} from GitHub only when it is not already installed,
# then attach it. Calling remotes::install_github() via the namespace
# avoids attaching {remotes} and skips a needless reinstall on each run.
if (!requireNamespace("fastknn", quietly = TRUE)) {
  remotes::install_github("davpinto/fastknn")
}
library(fastknn)

{RANN}、{foreach}、{Metrics}、{matrixStats}、{ggplot2}、{viridis}などのパッケージに依存している。{RANN}や{foreach}に依存していることからもわかるように、近傍探索の高速化や並列化によって、大きなサイズのデータでも高速に動作するパッケージであることがわかる。

K近傍を用いた特徴量エンジニアリング

{fastknn}を使うことで、K近傍を用いた特徴量エンジニアリングが可能。ざっくり説明すると、各クラスラベル(c個)ごとに、そのクラスに属する訓練データのk近傍までの距離をもとに、k×c個の特徴量を生成する。なので、クラス数c=3で、近傍数をk=5とすると、15個の特徴量が生成される。

イメージとしては、あるクラスに属する訓練データの最近傍までの距離を1つ目の特徴量として、第2近傍までの距離の和(最近傍までの距離+第2近傍までの距離)を2つ目の特徴量として……という感じでこれをkに関して繰り返すことで特徴量を作る。

具体的な計算イメージについては、こちらの記事に詳しくのっている。

ここでは、{fastknn}のGitHub上の例を再現する。この例では、K近傍を用いた特徴量エンジニアリングを行うことで、Accuracyが10%以上向上している。

library(mlbench)
library(caTools)
library(glmnet)

# Load the Ionosphere data and split it into a numeric feature
# matrix `x` and the class-label vector `y`.
data("Ionosphere", package = "mlbench")
y <- Ionosphere$Class
x <- data.matrix(Ionosphere[, names(Ionosphere) != "Class"])

# Drop the first two columns (near-zero variance)
x <- x[, -(1:2)]

# Stratified 70% / 30% train-test split
set.seed(123)
tr.idx <- which(sample.split(Y = y, SplitRatio = 0.7))
x.tr <- x[tr.idx, ]
y.tr <- y[tr.idx]
x.te <- x[-tr.idx, ]
y.te <- y[-tr.idx]

# GLM (unpenalized logistic regression, lambda = 0) on the original features.
# NOTE: the original code assigned the model to `glm`, which masks
# stats::glm(); use a distinct name to avoid shadowing the base function.
fit.orig <- glmnet(x = x.tr, y = y.tr, family = "binomial", lambda = 0)
yhat <- drop(predict(fit.orig, x.te, type = "class"))
yhat1 <- factor(yhat, levels = levels(y.tr))

# Generate KNN features: with c = 2 classes and k = 3 neighbours,
# knnExtract() yields 2 * 3 = 6 new columns for train and test.
set.seed(123)
new.data <- knnExtract(x.tr, y.tr, x.te, k = 3)
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |=====================                                                 |  30%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |=================================================                     |  70%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |===============================================================       |  90%
  |                                                                            
  |======================================================================| 100%
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |======================================================================| 100%
# GLM (unpenalized logistic regression, lambda = 0) on the KNN features.
# NOTE: the original code assigned the model to `glm`, which masks
# stats::glm(); use a distinct name to avoid shadowing the base function.
fit.knn <- glmnet(x = new.data$new.tr, y = y.tr, family = "binomial", lambda = 0)
yhat <- drop(predict(fit.knn, new.data$new.te, type = "class"))
yhat2 <- factor(yhat, levels = levels(y.tr))

# Accuracy (%) with the original features vs. the KNN features
lapply(
  list(yhat1, yhat2),
  function(pred) {
    sprintf("Accuracy: %.2f", 100 * (1 - classLoss(actual = y.te, predicted = pred)))
  }
)
## [[1]]
## [1] "Accuracy: 83.81"
## 
## [[2]]
## [1] "Accuracy: 96.19"

K近傍を用いた特徴量エンジニアリングを行っている部分はここ。kには近傍の点としていくつ使用するかを設定。

# KNN feature extraction (same call as above); `k` sets how many
# nearest neighbours per class are used to build the distance features
knnExtract(xtr = x.tr, ytr = y.tr, xte = x.te, k = 3)
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |=====================                                                 |  30%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |=================================================                     |  70%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |===============================================================       |  90%
  |                                                                            
  |======================================================================| 100%
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |======================================================================| 100%
## $new.tr
##            knn1      knn2      knn3     knn4      knn5      knn6
##   [1,] 1.499212  3.811036  6.275441 0.904031  2.150202  3.437391
##   [2,] 2.783883  5.664107  8.565420 2.618879  5.395238  8.385702
##   [3,] 3.789290  7.609333 11.484334 3.764956  7.551773 11.382875
##   [4,] 1.965208  4.260902  6.907153 1.269413  2.646760  4.329151
##   [5,] 0.440177  0.921218  1.585418 0.956144  2.448385  4.278621
##   [6,] 2.422808  5.555677  8.738436 0.671978  1.418216  2.544402
##   [7,] 3.350669  6.757612 10.364046 2.791661  5.613947  8.571800
##   [8,] 1.855603  4.495724  7.294386 0.612756  1.245848  2.174540
##   [9,] 2.401206  5.359756  8.533233 2.078666  4.286245  6.499967
##  [10,] 2.901313  6.114812  9.482911 3.259588  6.634625 10.069889
##  [11,] 1.707760  3.439244  5.194714 0.381040  0.844549  1.311158
##  [12,] 5.291503 10.637882 16.023047 5.484870 10.979314 16.498124
##  [13,] 1.916580  3.894994  5.909458 1.158238  2.388691  3.635843
##  [14,] 4.424485  9.149921 13.883227 4.632771  9.331058 14.035884
##  [15,] 4.242641  8.635833 13.064419 4.444464  8.947483 13.671855
##  [16,] 1.204678  2.985110  5.092694 0.760031  1.532056  2.304442
##  [17,] 4.358899  8.828623 13.318863 4.489055  8.988832 13.629498
##  [18,] 3.084800  6.299565  9.647767 1.663652  3.390426  5.168736
##  [19,] 3.958525  8.297961 12.681008 4.379970  8.918652 13.589905
##  [20,] 3.061464  6.198125  9.365819 1.515396  3.514206  5.601439
##  [21,] 3.389241  7.051749 10.855521 3.436296  7.107379 10.823824
##  [22,] 1.532953  3.951040  6.536600 0.633092  1.643387  2.709468
##  [23,] 2.566616  5.293893  8.545345 2.518154  5.045730  7.593416
##  [24,] 2.457676  4.950765  7.444100 1.536650  3.152168  4.773618
##  [25,] 3.249752  6.555965  9.945207 2.457676  5.432794  8.529561
##  [26,] 1.614598  3.616030  5.849119 0.259826  0.562458  0.886268
##  [27,] 4.524164  9.106739 13.728169 4.637076  9.317687 14.001870
##  [28,] 2.199965  4.478010  6.809966 1.613360  3.281711  4.967512
##  [29,] 3.975151  8.392040 12.831040 3.974700  7.959252 11.946199
##  [30,] 2.031132  4.867018  7.975863 0.612756  1.323881  2.070119
##  [31,] 2.042406  4.294291  6.584124 0.302632  0.757682  1.214640
##  [32,] 4.511498  9.043556 13.588784 4.507021  9.031219 13.566130
##  [33,] 4.812509  9.627932 14.492889 1.666803  3.594224  5.556686
##  [34,] 5.049907 10.197874 15.395440 5.276922 10.664983 16.059035
##  [35,] 1.734598  4.021697  6.415691 0.545419  1.115800  1.819894
##  [36,] 4.350193  8.813844 13.339522 4.407119  8.873614 13.347612
##  [37,] 2.187218  4.505066  7.068442 1.081943  2.169706  3.264379
##  [38,] 4.701868  9.560995 14.431123 5.036908 10.123978 15.244749
##  [39,] 1.668351  3.504240  5.552769 0.473098  1.078007  1.698618
##  [40,] 3.480149  7.057499 10.636290 3.558466  7.253341 11.012387
##  [41,] 3.798605  7.881700 12.053279 3.947151  7.915632 11.891460
##  [42,] 2.331956  4.733163  7.136185 1.594107  3.338599  5.088729
##  [43,] 3.757415  7.779988 11.961405 3.525040  7.247427 10.979552
##  [44,] 3.806960  7.732874 11.813373 1.413274  2.839242  4.348947
##  [45,] 2.230501  4.545397  7.045792 0.897131  2.033256  3.517664
##  [46,] 2.101038  4.617595  7.141716 0.311146  1.357237  2.490891
##  [47,] 3.757415  8.283894 12.819475 4.049118  8.215323 12.383561
##  [48,] 1.955471  4.306125  6.699599 0.311146  1.208277  2.124683
##  [49,] 3.893008  7.825055 11.838594 1.908329  3.954371  6.120147
##  [50,] 4.549312  9.357814 14.173338 4.679054  9.383191 14.118504
##  [51,] 2.695647  6.025391  9.355599 1.536565  3.327661  5.168167
##  [52,] 4.000000  8.224369 12.541716 3.794580  7.749902 11.715722
##  [53,] 2.289460  4.588972  7.137208 0.244105  1.306350  2.483612
##  [54,] 1.940924  3.887418  5.888743 1.250046  2.899385  4.683825
##  [55,] 1.349565  3.314243  5.306548 0.684381  1.436445  2.216438
##  [56,] 2.215582  4.529339  6.996583 1.379817  2.987835  4.714021
##  [57,] 1.932302  3.955817  5.982529 0.590428  1.194860  1.811334
##  [58,] 1.745819  3.830361  6.030325 1.754834  3.566512  5.483930
##  [59,] 1.526006  3.076644  4.891834 0.647032  1.296004  1.997979
##  [60,] 1.862526  3.734972  5.620979 1.553494  3.196585  4.868816
##  [61,] 1.839185  4.119341  6.401035 0.916406  1.962497  3.098623
##  [62,] 1.420679  3.028696  4.965038 1.073582  2.156924  3.279676
##  [63,] 2.289122  4.636937  7.053381 2.248552  4.654121  7.139352
##  [64,] 1.687962  3.383957  5.311368 0.738440  1.591774  2.505606
##  [65,] 2.639144  5.481274  8.460349 2.014182  4.066688  6.179005
##  [66,] 2.014182  4.667554  7.987471 0.440032  1.028027  2.577911
##  [67,] 3.670783  7.530255 11.705520 4.216831  8.514090 12.813803
##  [68,] 1.471258  3.197445  5.033216 0.519965  1.048615  1.638350
##  [69,] 1.794177  4.132391  6.472115 0.259826  0.567981  1.023031
##  [70,] 2.211604  4.671259  7.456633 1.287758  2.587380  3.912432
##  [71,] 1.834423  3.911023  6.015644 2.132846  4.279335  6.447479
##  [72,] 1.588777  3.593691  5.775469 0.442459  0.890205  1.360378
##  [73,] 2.624771  5.408002  8.201927 2.708713  5.426128  8.177069
##  [74,] 3.053554  6.509842 10.111682 1.556054  3.117903  4.681982
##  [75,] 2.846734  5.768837  8.698076 2.840344  5.710767  8.594170
##  [76,] 2.203024  4.523574  7.568293 0.533672  1.519523  3.112161
##  [77,] 1.916490  3.871182  5.850929 1.729706  3.584881  5.591598
##  [78,] 1.842842  4.058424  6.336469 0.943928  2.148606  3.361832
##  [79,] 1.520742  3.134102  4.995431 0.589734  1.207223  1.836286
##  [80,] 2.740997  5.520145  8.305942 2.883997  5.827471  8.795458
##  [81,] 1.750824  3.830735  6.119786 0.485719  1.034946  1.592696
##  [82,] 1.755879  3.538998  5.365209 1.279039  2.627672  4.035656
##  [83,] 2.366946  4.937500  7.645484 2.020901  4.048979  6.100426
##  [84,] 2.369979  4.792424  7.269437 2.500908  5.053145  7.706163
##  [85,] 1.532656  3.237419  4.942347 0.764403  1.611661  2.530225
##  [86,] 3.214270  6.580504  9.947387 3.464822  6.940317 10.482913
##  [87,] 1.809793  4.267858  6.919731 1.245673  2.558813  3.966737
##  [88,] 2.084541  4.673667  7.612763 1.899938  3.923574  5.979928
##  [89,] 2.574554  5.336496  8.235501 3.038458  6.081407  9.128355
##  [90,] 1.894711  3.996549  6.175435 0.169320  0.352994  0.581478
##  [91,] 2.467244  5.040761  7.629516 1.856973  3.753032  5.659371
##  [92,] 2.603570  5.918458  9.585582 3.274054  6.668097 10.075728
##  [93,] 2.588755  5.192325  8.332746 2.877865  5.807563  8.852930
##  [94,] 3.701209  7.420018 11.157600 3.572884  7.194390 10.847965
##  [95,] 1.279039  3.170182  5.139053 0.736923  1.519773  2.337088
##  [96,] 1.847021  3.727967  5.758757 0.183674  0.400418  0.662554
##  [97,] 0.704039  1.453828  2.371767 0.346984  1.686507  3.163356
##  [98,] 1.395390  2.843876  4.399330 1.626561  3.272374  5.348607
##  [99,] 1.809731  3.766518  5.955126 0.208062  0.447401  0.722987
## [100,] 1.745819  4.585132  7.473555 2.083130  4.275814  6.545754
## [101,] 1.920026  3.893154  6.044928 0.214871  0.454130  0.711185
## [102,] 1.438045  3.183281  5.153077 0.528650  1.062882  1.642197
## [103,] 2.749962  5.589783  8.544605 0.303423  0.721809  1.361294
## [104,] 0.268911  0.749952  1.331815 0.678115  1.427904  2.789993
## [105,] 1.443619  3.283168  5.257796 0.371895  0.762041  1.190028
## [106,] 1.687059  3.379268  5.159453 1.287601  2.627124  3.974240
## [107,] 2.970017  5.970467  8.980932 0.773456  1.576433  2.464600
## [108,] 0.268911  0.709088  1.218046 0.645495  1.349534  2.773685
## [109,] 1.833892  3.827118  5.872611 0.326779  0.705388  1.133376
## [110,] 5.126666 10.461603 16.017264 5.434846 10.901009 16.374127
## [111,] 3.207512  6.533095  9.862595 0.505720  1.091830  1.865286
## [112,] 4.120331  8.300093 12.534951 3.505260  7.014236 10.545820
## [113,] 4.380800  8.935059 13.536903 4.647902  9.476586 14.451819
## [114,] 3.712971  7.456893 11.270518 1.338841  2.853033  4.516773
## [115,] 0.875211  1.836063  2.826636 1.220794  2.914500  4.672483
## [116,] 1.811678  3.710840  5.807039 1.058131  2.176397  3.304268
## [117,] 4.472136  8.944272 13.445594 4.693292  9.405741 14.132428
## [118,] 3.092329  6.246620  9.496620 3.502867  7.270298 11.100154
## [119,] 1.921750  3.881214  5.842738 1.865918  3.851665  5.992768
## [120,] 1.674501  3.374431  5.517549 0.631736  1.282401  1.953303
## [121,] 0.508959  1.090822  1.699880 0.917939  2.500160  4.101482
## [122,] 2.317581  5.440278  8.653179 0.640621  1.351746  2.417827
## [123,] 3.098763  6.519626  9.945663 3.221813  6.529429  9.893219
## [124,] 2.313854  4.932733  7.699907 0.894067  1.798098  2.909054
## [125,] 1.982150  4.650657  7.527162 0.894067  2.005760  3.292948
## [126,] 3.332652  6.748456 10.165024 0.285386  0.871496  1.476007
## [127,] 4.648549  9.444381 14.343703 4.616060  9.387659 14.165707
## [128,] 4.403215  9.074810 13.825606 1.975194  4.123561  6.453436
## [129,] 4.827309  9.678261 14.532000 4.813270  9.857905 15.027701
## [130,] 2.886871  5.828488  8.925124 1.851616  3.995108  6.147674
## [131,] 1.919301  3.906573  6.093540 1.111300  2.285433  3.464852
## [132,] 4.307671  9.170495 14.125648 5.043563 10.167998 15.305643
## [133,] 1.594107  3.354038  5.561618 0.837265  1.678029  2.525866
## [134,] 3.186087  6.682361 10.265516 3.545373  7.153227 10.770946
## [135,] 1.757314  3.826591  6.261570 0.942854  1.901345  2.862973
## [136,] 1.635879  3.467133  5.575560 0.702781  1.415868  2.151507
## [137,] 3.395305  7.045257 10.796128 3.279261  6.591311  9.966172
## [138,] 3.031877  6.103727  9.233557 0.640621  1.312599  2.281504
## [139,] 4.072047  8.196104 12.329645 4.141580  8.440270 12.865381
## [140,] 1.860213  3.782066  5.707965 0.403579  0.808793  1.233832
## [141,] 4.792819  9.703543 14.661062 4.814803  9.689120 14.729599
## [142,] 1.750556  3.611315  5.696573 0.772330  1.608378  2.472794
## [143,] 2.044389  4.134562  6.348798 2.028107  4.150237  6.625649
## [144,] 2.747234  5.572044  8.397719 2.759541  5.656387  8.713926
## [145,] 2.023636  4.346626  6.676206 1.310897  2.752359  4.255915
## [146,] 3.214289  7.009629 11.142737 3.677880  7.518290 11.367495
## [147,] 3.166633  6.453424  9.745949 0.285386  0.857848  1.746015
## [148,] 2.030751  4.140003  6.333252 0.800098  1.827701  2.939001
## [149,] 4.888999  9.846953 15.081736 4.998273 10.407487 15.841585
## [150,] 1.151338  2.401384  4.513701 1.032419  2.176813  3.613496
## [151,] 4.038085  8.113504 12.530206 3.703740  7.459373 11.322497
## [152,] 2.403818  4.819060  7.645223 0.244105  1.044203  2.147196
## [153,] 2.828427  5.799964  8.839793 3.041812  6.224448  9.435744
## [154,] 3.605551  7.283797 10.976817 3.353205  6.882574 10.510805
## [155,] 3.343109  6.867130 10.867130 3.646407  7.512272 11.446195
## [156,] 2.061406  4.264073  6.690072 1.315035  2.642636  4.096160
## [157,] 1.926372  3.922681  6.018666 0.187687  0.379075  0.587137
## [158,] 3.554145  7.554145 11.677250 4.403723  8.876754 13.369081
## [159,] 2.403022  4.957605  7.599534 1.635879  3.310380  5.067694
## [160,] 2.011412  4.459937  6.965216 1.481974  3.041397  4.611445
## [161,] 1.272875  2.567703  3.969705 1.088463  2.239801  3.940514
## [162,] 1.321787  3.217847  5.199537 0.454505  0.940106  1.530399
## [163,] 0.550217  1.159276  1.807008 0.872150  1.803572  3.441898
## [164,] 1.424546  3.270694  5.198212 0.505067  1.098019  1.702451
## [165,] 1.985781  3.983073  6.001077 2.235885  4.743696  7.326321
## [166,] 1.558077  3.634664  5.826804 0.203863  0.607382  1.025781
## [167,] 2.362348  4.779125  7.202170 2.368351  4.809221  7.401535
## [168,] 1.949156  3.968188  6.148749 0.203863  0.560863  0.921898
## [169,] 1.408956  3.318355  5.333292 0.326779  0.687813  1.059708
## [170,] 2.316879  4.636863  6.980797 2.348149  4.822974  7.437841
## [171,] 1.387325  2.860961  4.705418 0.546930  1.118677  1.708971
## [172,] 1.000000  2.305870  3.705397 1.568706  3.169677  5.199597
## [173,] 3.607263  7.350570 11.173985 3.694409  7.420349 11.221919
## [174,] 3.041687  6.126896  9.229572 0.375009  0.871641  1.529678
## [175,] 2.587780  5.423518  8.265167 2.896092  5.827601  8.847879
## [176,] 3.341100  6.732035 10.127414 0.922069  1.989280  3.123020
## [177,] 2.068941  4.167908  6.312092 0.339661  2.189523  4.130580
## [178,] 1.361753  2.976892  4.908946 0.407843  0.893443  1.472758
## [179,] 1.351070  3.036871  4.794592 0.519965  1.054197  1.738552
## [180,] 1.524550  3.125873  4.826585 0.647032  1.384479  2.125070
## [181,] 1.748969  3.533409  5.997472 0.412419  1.444838  2.507084
## [182,] 2.243045  4.583327  7.328751 0.285815  0.873810  1.830075
## [183,] 1.553494  3.562526  5.571915 1.512496  3.098253  4.762290
## [184,] 1.868879  3.811942  5.869268 1.231434  2.472940  3.767567
## [185,] 1.612214  3.352383  5.464551 0.921570  1.851893  2.791365
## [186,] 0.943928  2.641647  4.563981 0.694702  1.431280  2.183398
## [187,] 1.348633  3.071334  5.062743 1.051923  2.158072  3.390504
## [188,] 1.379817  3.134651  4.969052 0.894931  1.910772  2.933807
## [189,] 1.253376  2.887425  4.626628 0.407843  0.862349  1.409279
## [190,] 1.241100  2.782457  4.666688 0.696319  1.402552  2.154670
## [191,] 1.407984  3.046310  4.707052 0.760386  1.524789  2.357913
## [192,] 1.849384  3.840386  6.336089 0.412419  1.556813  2.734075
## [193,] 2.052506  4.950661  8.323472 0.285815  1.404871  2.757741
## [194,] 2.368640  4.775417  8.005309 0.533369  1.755902  3.304724
## [195,] 2.273261  4.858746  7.611217 0.440032  0.973401  1.507073
## [196,] 2.462507  5.758752  9.093211 0.705633  2.422034  4.266291
## [197,] 2.489473  5.213768  7.955896 0.735439  1.483081  2.413797
## [198,] 0.645495  1.323610  2.195761 0.346984  1.634585  3.085824
## [199,] 1.245820  3.454254  5.727152 1.105039  2.233199  3.418229
## [200,] 1.573838  3.248824  5.003229 0.937082  1.911359  2.898631
## [201,] 1.500305  3.287477  5.139668 0.278391  0.668536  1.068857
## [202,] 1.768204  3.668142  5.585560 0.751584  1.504643  2.262459
## [203,] 1.563454  3.396668  5.365294 0.494735  1.022330  1.559028
## [204,] 1.551626  3.193597  5.167992 0.841774  1.695228  2.681226
## [205,] 1.264183  3.114796  5.171634 0.809815  1.624184  2.493586
## [206,] 2.018008  4.067681  6.200176 0.514246  1.068867  1.652285
## [207,] 1.243221  3.258020  5.327715 0.827986  1.671037  2.514811
## [208,] 1.241482  3.098454  5.176593 0.571747  1.230230  1.920896
## [209,] 2.486276  5.698474  9.069156 0.766841  2.601877  4.505772
## [210,] 2.865606  6.458141 10.095661 0.705633  1.472474  2.849986
## [211,] 3.036847  6.718727 10.445353 0.554981  1.448942  2.792767
## [212,] 2.631324  5.479504  8.482468 0.457155  1.211905  1.989927
## [213,] 1.929964  4.162310  6.418376 0.308155  0.631964  0.966301
## [214,] 1.849524  3.910371  6.061225 0.216744  0.445228  0.732569
## [215,] 1.405683  3.312021  5.405735 0.917703  1.883087  2.929028
## [216,] 3.282555  6.817894 10.359958 1.336609  2.885430  4.435315
## [217,] 2.625346  5.679654  8.838792 0.457155  1.263045  2.167940
## [218,] 3.252044  6.713174 10.470768 0.554981  1.428452  2.333347
## [219,] 3.428851  7.070237 10.734459 0.873471  1.766926  2.665121
## [220,] 1.929022  3.908316  5.962047 0.339362  0.705277  1.118583
## [221,] 1.787307  3.667567  5.913465 0.899745  1.804853  2.737481
## [222,] 3.610332  7.225419 10.845988 0.608634  1.266671  2.018455
## [223,] 1.453397  3.391603  5.448770 0.802622  1.630700  2.536557
## [224,] 1.750129  3.590354  5.536353 0.294335  0.626180  1.005512
## [225,] 1.701212  3.507705  5.434989 0.271657  0.558998  0.853332
## [226,] 1.744492  3.581543  5.590675 0.244740  0.528963  0.841172
## [227,] 1.602881  3.220384  5.110840 0.740591  1.652802  2.567304
## [228,] 2.717229  5.643984  8.637350 0.303423  0.761679  1.425074
## [229,] 3.134321  6.326451  9.589788 0.375009  0.983643  1.623127
## [230,] 3.701493  7.430669 11.162984 0.288626  1.202057  2.212871
## [231,] 2.907473  5.833603  8.774497 1.142005  2.444078  3.764831
## [232,] 1.778153  3.653317  5.688541 0.216411  0.461151  0.718206
## [233,] 1.808694  3.690366  5.600386 0.381329  0.794404  1.263060
## [234,] 1.511068  3.334310  5.235083 0.278391  0.593327  0.925172
## [235,] 1.784692  3.741594  5.706867 0.138235  0.304487  0.492174
## [236,] 1.830320  3.756445  5.777850 0.166251  0.335571  0.526959
## [237,] 1.777055  3.816234  5.910653 0.138235  0.357556  0.605133
## [238,] 3.012846  6.026943  9.166485 0.418387  0.876643  1.373275
## [239,] 3.681214  7.403435 11.148038 0.288626  1.040410  1.962479
## [240,] 2.201242  4.430946  6.675684 0.339661  2.031348  3.784234
## [241,] 3.208735  6.447590  9.701949 0.505720  1.078182  1.682693
## [242,] 1.362089  2.786240  4.278482 1.451239  2.928088  5.154193
## [243,] 1.355011  3.148909  5.264396 0.463420  0.945154  1.446772
## [244,] 1.213227  2.688523  4.530180 0.690667  1.427462  2.183162
## [245,] 1.781338  3.693575  5.649253 0.460792  0.934080  1.418898
## [246,] 1.523231  3.389075  5.327649 0.398322  0.821500  1.259788
## 
## $new.te
##            knn1      knn2      knn3     knn4      knn5      knn6
##   [1,] 1.259176  3.476450  5.704539 0.846399  1.753174  2.669671
##   [2,] 2.450065  5.203995  8.333969 0.779574  1.603266  2.714035
##   [3,] 3.527862  7.133414 10.752424 3.553530  7.270805 11.037254
##   [4,] 2.045341  4.849451  7.768165 0.549818  1.232183  1.995172
##   [5,] 1.257051  2.527908  3.838505 1.538030  3.076789  4.935496
##   [6,] 1.748976  3.647918  5.612912 0.696067  1.596946  2.556994
##   [7,] 1.691148  4.097347  6.635406 0.491712  1.164605  1.887064
##   [8,] 1.331121  3.592245  5.931407 0.562907  1.303161  2.063949
##   [9,] 3.884835  7.770569 12.013044 2.044692  4.300826  6.623006
##  [10,] 5.103691 10.287515 15.482479 5.167132 10.414203 15.729443
##  [11,] 2.286159  5.066758  7.962681 0.791468  1.762724  2.901213
##  [12,] 1.632233  3.267679  5.045139 0.252950  0.728110  1.205764
##  [13,] 2.374229  4.808646  7.318464 0.352957  0.773066  1.370947
##  [14,] 4.739140  9.530992 14.333119 4.939948  9.959111 15.007664
##  [15,] 1.778888  3.686985  6.074023 0.947418  1.991178  3.173815
##  [16,] 3.480695  6.992928 10.648341 3.588557  7.215404 10.928286
##  [17,] 1.664535  4.147615  6.632443 0.388113  0.976277  1.657157
##  [18,] 3.105475  6.248747  9.398357 3.062351  6.301478  9.583431
##  [19,] 2.901638  5.970759  9.192714 3.188065  6.413571  9.691254
##  [20,] 2.017280  4.062272  6.376634 1.449904  2.950466  4.476422
##  [21,] 1.605867  3.648533  5.701770 0.681692  1.526544  2.379858
##  [22,] 2.360507  4.831404  7.404370 2.789478  5.597483  8.410647
##  [23,] 2.075752  4.562276  7.104115 0.610579  1.251616  1.983764
##  [24,] 1.967423  3.999573  6.335826 1.336801  2.695724  4.063890
##  [25,] 1.654850  3.465953  5.359452 0.685711  1.390004  2.190512
##  [26,] 3.660483  7.591608 11.530099 3.629142  7.261113 11.012736
##  [27,] 4.546272  9.216989 13.915976 4.604260  9.265591 13.959526
##  [28,] 2.509929  5.283448  8.092577 0.568297  1.176123  2.097319
##  [29,] 1.820272  3.732000  5.652473 1.882274  3.832548  5.831149
##  [30,] 4.690416  9.743040 14.815972 4.608398  9.238433 13.913276
##  [31,] 2.457443  5.014674  7.708328 2.646050  5.343929  8.115279
##  [32,] 2.420151  5.070191  7.913428 0.570090  1.202559  1.950040
##  [33,] 3.726258  7.478645 11.488293 3.376737  6.957617 10.615724
##  [34,] 0.000000  1.000000  2.291307 1.568706  3.169677  5.199597
##  [35,] 1.757484  4.117776  6.544462 0.616331  1.286063  1.969647
##  [36,] 2.352704  4.713685  7.190194 2.384149  4.777282  7.176218
##  [37,] 2.323453  5.002655  7.682328 2.448126  5.089015  7.737147
##  [38,] 2.490764  5.149147  7.847956 0.491290  1.248830  2.136450
##  [39,] 1.071048  2.169716  3.293542 0.870779  1.770648  2.772338
##  [40,] 1.803454  3.718810  5.799789 0.225815  0.463069  0.731653
##  [41,] 1.247438  2.533695  3.825218 0.695742  1.491567  2.341604
##  [42,] 1.543770  3.280494  5.106184 1.186972  2.411270  3.648075
##  [43,] 2.370412  5.123271  7.972136 0.764704  1.537884  2.398712
##  [44,] 0.654749  1.314412  2.014412 1.019520  2.119946  3.650722
##  [45,] 1.074360  3.000004  5.348676 1.186383  2.449506  3.718377
##  [46,] 2.865757  6.678948 10.541086 3.913940  7.950688 12.003415
##  [47,] 2.402481  5.175145  8.063962 1.848896  3.753394  5.700377
##  [48,] 0.466678  1.012914  1.598161 0.700931  1.434401  2.938944
##  [49,] 0.951504  2.669784  4.574470 0.662197  1.326179  2.090746
##  [50,] 1.657950  3.370372  5.120902 1.620724  3.247199  5.315266
##  [51,] 0.284857  0.578432  1.079464 0.712752  1.460628  2.926432
##  [52,] 0.195534  0.465710  0.860273 0.726132  1.494154  2.842262
##  [53,] 0.679107  1.430198  2.192113 0.950510  1.915097  3.091411
##  [54,] 3.760021  7.664824 11.606053 0.725539  1.559911  2.406528
##  [55,] 1.624870  3.641013  5.813302 0.645041  1.307307  1.972541
##  [56,] 1.947403  4.132465  6.573990 0.793890  1.756538  2.720096
##  [57,] 1.811310  3.838102  5.874535 0.897951  1.799094  2.731107
##  [58,] 3.259693  6.562180  9.904394 3.517073  7.081854 10.667818
##  [59,] 4.077413  8.320053 12.605327 4.473436  8.977567 13.498481
##  [60,] 4.295053  8.711417 13.146513 4.296550  8.651754 13.008685
##  [61,] 3.162278  6.429138  9.745762 3.499368  7.120361 10.776956
##  [62,] 1.407958  2.819107  4.251652 1.479185  2.991046  4.615621
##  [63,] 2.434606  4.891622  7.356787 2.531766  5.096158  7.928568
##  [64,] 3.130977  6.264354  9.420325 3.164673  6.541987  9.971182
##  [65,] 4.448864  8.995691 13.587312 4.320847  8.756432 13.214041
##  [66,] 1.804666  3.619036  5.520819 0.419374  0.877017  1.350255
##  [67,] 1.839882  4.163700  6.571781 1.348266  2.826946  4.443573
##  [68,] 4.219478  8.491480 12.793643 4.255667  8.757658 13.277545
##  [69,] 1.541335  3.128555  4.824799 0.632277  1.418425  2.225293
##  [70,] 3.327996  7.525232 11.988035 3.337183  7.158993 11.079670
##  [71,] 3.671625  7.793818 11.916924 3.860182  7.871780 11.884488
##  [72,] 2.345121  4.707843  7.326072 0.620033  1.325817  2.483277
##  [73,] 2.025595  4.073841  6.164783 0.739719  1.494670  2.356965
##  [74,] 2.959088  5.930332  8.913019 0.284736  0.699475  1.116664
##  [75,] 1.000000  2.365472  3.781026 1.587400  3.195675  4.997563
##  [76,] 1.506415  3.223800  5.012123 0.567952  1.163614  1.816533
##  [77,] 1.678285  4.091492  6.852111 2.129290  4.347984  6.585648
##  [78,] 2.142011  4.333947  6.614111 1.325653  2.691074  4.135468
##  [79,] 3.390118  6.800136 10.230262 0.243295  0.695918  1.290229
##  [80,] 1.530119  3.359071  5.208312 0.806376  1.647887  2.496932
##  [81,] 1.349612  3.197348  5.076960 0.380287  0.831554  1.303384
##  [82,] 2.195056  4.491538  6.994923 0.458153  1.120481  2.385080
##  [83,] 1.943553  3.946067  5.993679 0.408648  0.858888  1.763736
##  [84,] 2.188479  4.773859  7.409807 0.613081  1.396531  2.689519
##  [85,] 1.823317  3.660593  5.709756 0.233449  0.470273  0.718708
##  [86,] 1.303091  3.096806  4.980598 0.449050  1.002041  1.628370
##  [87,] 1.201109  2.824740  4.599784 0.425176  0.890645  1.401754
##  [88,] 1.199569  3.252980  5.390863 0.653994  1.328022  2.011626
##  [89,] 2.843196  6.150032  9.651075 0.878048  2.057569  3.272024
##  [90,] 2.408190  5.203711  8.161786 0.496489  1.097706  1.926019
##  [91,] 1.752692  3.604824  5.717658 0.413913  0.874845  1.446191
##  [92,] 1.646037  3.299394  4.956520 1.026958  2.090481  3.185734
##  [93,] 2.565718  5.525619  8.545564 0.733607  1.471951  2.722300
##  [94,] 2.363888  4.943639  7.601278 0.542282  1.098660  1.748608
##  [95,] 3.024575  6.099569  9.367940 0.449445  0.998845  1.576605
##  [96,] 1.371566  3.086779  4.902738 0.409795  0.885912  1.428292
##  [97,] 1.897994  3.806492  5.964578 0.099965  0.324300  0.551452
##  [98,] 3.403087  6.827563 10.267197 1.574755  3.199823  4.878318
##  [99,] 1.701264  3.480447  5.271300 0.336330  0.706120  1.112176
## [100,] 1.774490  3.703645  5.657161 0.132606  0.297545  0.462613
## [101,] 1.350482  3.147555  4.986636 0.331664  0.731011  1.137582
## [102,] 1.611081  3.558301  5.601278 1.049712  2.477103  3.989761
## [103,] 1.750953  3.585330  5.509206 0.214947  0.447989  0.683492
## [104,] 1.759826  3.556869  5.512745 0.184236  0.397577  0.632344
## [105,] 1.632925  3.430550  5.534143 0.304634  0.661185  1.063553

アウトプットには訓練データとテストデータ用の新しい特徴量が生成される。

# The result is a list with the new feature matrices for train and test
str(new.data)
## List of 2
##  $ new.tr: num [1:246, 1:6] 1.5 2.78 3.78 1.97 0.44 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : NULL
##   .. ..$ : chr [1:6] "knn1" "knn2" "knn3" "knn4" ...
##  $ new.te: num [1:105, 1:6] 1.26 2.45 3.53 2.05 1.26 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : NULL
##   .. ..$ : chr [1:6] "knn1" "knn2" "knn3" "knn4" ...

ナニヲシテイルノカ?

K近傍を用いた特徴量エンジニアリングは何をやっているのか…。ざっくり説明すると、KNNによって元の空間を非線形に写像することで、クラスが線形に分離しやすい新しい特徴量を生成しているとのこと。

library(caTools)
library(fastknn)
library(ggplot2)
library(gridExtra)

# Load the two-class "chess" toy data shipped with {fastknn}
data("chess", package = "fastknn")
x <- data.matrix(chess[["x"]])
y <- chess[["y"]]

# Stratified 70% / 30% train-test split (same seed as before)
set.seed(123)
train.rows <- which(sample.split(Y = y, SplitRatio = 0.7))
x.tr <- x[train.rows, ]
y.tr <- y[train.rows]
x.te <- x[-train.rows, ]
y.te <- y[-train.rows]

# Extract KNN features using only the single nearest neighbour (k = 1)
set.seed(123)
new.data <- knnExtract(xtr = x.tr, ytr = y.tr, xte = x.te, k = 1)
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |=====================                                                 |  30%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |=================================================                     |  70%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |===============================================================       |  90%
  |                                                                            
  |======================================================================| 100%
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |======================================================================| 100%
# Decision boundaries
# Side-by-side comparison of a k = 10 KNN decision surface on the original
# 2-D features (g1) versus the same classifier on the extracted KNN features
# (g2); per the discussion above, the KNN feature space should make the
# classes (nearly) linearly separable.
g1 <- knnDecision(x.tr, y.tr, x.te, y.te, k = 10) +
  labs(title = "Original Features")
g2 <- knnDecision(new.data$new.tr, y.tr, new.data$new.te, y.te, k = 10) +
  labs(title = "KNN Features")
grid.arrange(g1, g2, ncol = 2)

# Load data
# "spirals" is another 2-D toy dataset from fastknn; the name suggests
# intertwined spiral-shaped classes (not linearly separable in the original
# space) — confirm against the fastknn docs.
data("spirals")
x <- data.matrix(spirals$x)
y <- spirals$y

# Split data
# Same stratified 70/30 split and seed as the chess example above.
set.seed(123)
tr.idx <- which(sample.split(Y = y, SplitRatio = 0.7))
x.tr <- x[tr.idx,]
x.te <- x[-tr.idx,]
y.tr <- y[tr.idx]
y.te <- y[-tr.idx]

# Feature extraction with KNN (k = 1, as in the chess example)
set.seed(123)
new.data <- knnExtract(x.tr, y.tr, x.te, k = 1)
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |=====================                                                 |  30%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |=================================================                     |  70%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |===============================================================       |  90%
  |                                                                            
  |======================================================================| 100%
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |======================================================================| 100%
# Decision boundaries: original 2-D features vs. extracted KNN features,
# each visualized with a k = 10 KNN decision surface and shown side by side.
g1 <- knnDecision(x.tr, y.tr, x.te, y.te, k = 10)
g1 <- g1 + labs(title = "Original Features")
g2 <- knnDecision(new.data$new.tr, y.tr, new.data$new.te, y.te, k = 10)
g2 <- g2 + labs(title = "KNN Features")
grid.arrange(g1, g2, ncol = 2)

最適なkを求める

ハイパーパラメータであるkを決める必要があるのですが、このパッケージでは、様々な評価指標のもとでクロスバリデーションを行い、最適なkを決めることができます。評価指標としては「overall_error」「mean_error」「auc」「logloss」が利用可能です。

この例の場合、loglossのもとでクロスバリデーションした結果、loglossを最も小さくするkは10であることがわかります。

# Load dataset
# Sonar: 208 observations, 60 numeric predictors (columns 1:60) and the class
# label in column 61.
library("mlbench")
data("Sonar", package = "mlbench")
x <- data.matrix(Sonar[, -61])  # drop the label column, keep predictors
y <- Sonar$Class

# 5-fold CV using log-loss as evaluation metric
# Searches k over 3:15 and reports the per-fold and mean metric for each k.
set.seed(123)
cv.out <- fastknnCV(x,
                    y,
                    k = 3:15,
                    method = "vote", # method = "dist" is also available
                    folds = 5,
                    eval.metric = "logloss")
## 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |======================================================================| 100%
# Per-fold log-loss for each candidate k; the "mean" column below shows that
# k = 10 achieves the smallest average log-loss.
cv.out$cv_table
##       fold_1    fold_2    fold_3    fold_4    fold_5      mean  k
## 1  1.0185769 4.3488958 5.2774367 3.5390275 1.7839039 3.1935682  3
## 2  0.3302963 3.5667636 4.5300572 2.7776147 1.0502739 2.4510011  4
## 3  0.4046927 2.8207457 3.7682707 2.8348932 0.3424366 2.0342078  5
## 4  0.4655162 2.8307567 2.2128757 1.2546941 0.3610754 1.4249836  6
## 5  0.5302694 2.0623844 2.2345220 0.4860235 0.4155828 1.1457564  7
## 6  0.5468362 2.1109205 2.2323124 0.5080782 0.4481816 1.1692658  8
## 7  0.5782588 2.1231117 1.4562867 0.5286501 0.4478363 1.0268287  9
## 8  0.5947111 0.5556146 0.7008334 0.5156674 0.4565729 0.5646799 10
## 9  0.6152572 0.5767084 0.6726436 0.5208554 0.4610948 0.5693119 11
## 10 0.6263350 0.5951938 0.6712564 0.5339942 0.5083115 0.5870182 12
## 11 0.6119226 0.5866577 0.6610566 0.5430764 0.5154826 0.5836391 13
## 12 0.6137801 0.5715341 0.6623809 0.5434475 0.5298594 0.5842004 14
## 13 0.6132759 0.5883311 0.6594093 0.5739258 0.5543286 0.5978541 15