使用js机器学习库了解knn算法


K-近邻算法(KNN,k-NearestNeighbor)概述

首先K最近邻分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一

可以把数据样本分为训练数据和测试数据,输入一个测试数据,依次计算和每一个训练数据的距离,然后取出k个距离最近的样本,然后在k个样本中做一个投票,票数多的样本所属的类即为测试数据的类。所以为了好投票在这里的k值选择一般为奇数。

KNN算法图解

如下图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

KNN算法决策过程

KNN算法的js库的使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
const csv = require('csvtojson')
const KNN = require("ml-knn")

/** csv 数据格式
* "sepalLength","sepalWidth","petalLength","petalWidth","type"
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
*/

const csvFilePath = "./Iris.csv"

const _ = require("lodash")
const prompt = require("prompt")

let seperationSize; // 区分测试数据和训练数据的比例
let knn
let csvData = [],
x = [],
y = []

let trainingSetX = [],
trainingSetY = [],
testSetX = [],
testSetY = []

csv().fromFile(csvFilePath)
.on("json", (jsonObj) => {
csvData.push(jsonObj)
})
.on("done", () => {
//console.log(csvData)
seperationSize = 0.7 * csvData.length

//对数据进行洗牌打散
csvData = _.shuffle(csvData)

dressData()
})


//填充训练数据
function dressData() {
const types = _.uniq(_.map(csvData, "type"))

const typeMap = _.reduce(types, (r, type, index) => {
_.set(r, type, index)
return r
}, {})

//console.log(typeMap)

_.each(csvData, (n) => {
x.push(_.chain(n).keys().dropRight(1).map((key) => {
return _.toNumber(n[key])
}).value())
y.push(typeMap[n.type])
})

//准备训练数据
trainingSetX = x.slice(0, seperationSize)
trainingSetY = y.slice(0, seperationSize)

//准备测试数据
testSetX = x.slice(seperationSize)
testSetY = y.slice(seperationSize)

// console.log(testSetX, testSetY)
//训练knn
const knn = new KNN(trainingSetX, trainingSetY, { k: 7 })

//使用训练好的knn 测试
const result = knn.predict(testSetX)

// console.log('@@@@@@@@@@@@@', result)
// const testSetLength = testSetX.length
// const predictionError = error(result, testSetY)
// console.log(`Test Set Size = ${testSetLength} and number of Misclassifications = ${predictionError}`)
// predict();
error(result, testSetY)
}

//使用预测出的数据和正确的数据 比较正确率
function error(predicted, expected) {
let errorCount = 0

_.each(predicted, (v, i) => {
if (v != expected[i]) {
console.log(`error: @predicted: ${v};@expected: ${expected[i]}`)
errorCount += 1
} else {
console.log(`good: @predicted: ${v};@expected: ${expected[i]}`)
}
})

console.log(`errror count is: ${errorCount}, all count is: ${predicted.length} , error ratio is: ${_.divide(errorCount, predicted.length)}`)
}

文章作者: Callable
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Callable !
评论
  目录