您好,登錄后才能下訂單哦!
前言:
最近在研究機(jī)器學(xué)習(xí),過程中的心得體會(huì)會(huì)記錄到blog里,文章與代碼均為原創(chuàng)。會(huì)不定期龜速更新,注意這不是正式的教程,因?yàn)楸救艘彩浅鯇W(xué)者,但是估計(jì)C#版本的代碼能幫到一些剛?cè)腴T的同學(xué)去理解復(fù)雜的公式。
------------------------ 我是分割線 ------------------------
k近鄰(k-Nearest Neighbor,KNN)算法,應(yīng)該是機(jī)器學(xué)習(xí)里最基礎(chǔ)的算法,其核心思想是:給定一個(gè)未知分類的樣本,如果與它最相似的k個(gè)已知樣本中的多數(shù)屬于某一個(gè)分類,那么這個(gè)未知樣本也屬于這個(gè)分類。
所謂相似,是指兩個(gè)樣本之間的歐氏距離小,其計(jì)算公式為:
其中Xi為樣本X的第i個(gè)特征。
k近鄰算法的優(yōu)點(diǎn)在于實(shí)現(xiàn)簡單,缺點(diǎn)在于時(shí)間和空間復(fù)雜度高。
上C#版代碼,這里取k=1,即只根據(jù)最相近的一個(gè)點(diǎn)確定分類:
首先是DataVector,包含N維數(shù)據(jù)和分類標(biāo)簽,用于表示一個(gè)樣本。
using System; namespace MachineLearning { /// <summary> /// 數(shù)據(jù)向量 /// </summary> /// <typeparam name="T"></typeparam> public class DataVector<T> { /// <summary> /// N維數(shù)據(jù) /// </summary> public T[] Data { get; private set; } /// <summary> /// 分類標(biāo)簽 /// </summary> public string Label { get; set; } /// <summary> /// 構(gòu)造 /// </summary> /// <param name="dimension">數(shù)據(jù)維度</param> public DataVector(int dimension) { Data = new T[dimension]; } public int Dimension { get { return this.Data.Length; } } } }
然后是核心算法:
using System; using System.Collections.Generic; namespace MachineLearning { /// <summary> /// k近鄰法 /// </summary> public class NearestNeighbour { private int m_K; private List<DataVector<double>> m_TrainingSet; public NearestNeighbour(int k = 1) { m_K = k; } /// <summary> /// 訓(xùn)練 /// </summary> /// <param name="trainingSet"></param> public void Train(List<DataVector<double>> trainingSet) { m_TrainingSet = trainingSet; } /// <summary> /// 分類 /// </summary> /// <param name="vector"></param> /// <returns></returns> public string Classify(DataVector<double> vector) { //K=1時(shí)可簡化處理提高效率 if(m_K == 1) { double minDist = double.PositiveInfinity; int targetIndex = -1; for(int i = 0;i < m_TrainingSet.Count;i++) { //計(jì)算距離 double distance = ComputeDistance(vector, m_TrainingSet[i], minDist); //找最小值 if(distance < minDist) { minDist = distance; targetIndex = i; } } return m_TrainingSet[targetIndex].Label; } else { var dict = new SortedDictionary<double, string>(); for(int i = 0;i < m_TrainingSet.Count;i++) { //計(jì)算距離并記錄 double distance = ComputeDistance(vector, m_TrainingSet[i]); dict[distance] = m_TrainingSet[i].Label; } //找最多的Label var labels = new List<string>(); int count = 0; foreach(var label in dict.Values) { labels.Add(label); if(++count > m_K - 1) break; } return GetMajorLabel(labels); } } /// <summary> /// 計(jì)算距離 /// </summary> /// <param name="v1"></param> /// <param name="v2"></param> /// <param name="minValue"></param> /// <returns></returns> private double ComputeDistance(DataVector<double> v1, DataVector<double> v2, double minValue = double.PositiveInfinity) { double distance = 0.0; minValue = minValue * minValue; for(int i = 0;i < v1.Data.Length;++i) { double diff = v1.Data[i] - v2.Data[i]; distance += diff * diff; //如果當(dāng)前累加的距離已經(jīng)大于給定的最小值,不用繼續(xù)計(jì)算了 if(distance > minValue) return double.PositiveInfinity; } return Math.Sqrt(distance); } /// <summary> /// 取多數(shù) /// </summary> /// <param name="dataSet"></param> /// <returns></returns> private string GetMajorLabel(List<string> labels) { var dict = new Dictionary<string, int>(); foreach(var item in labels) { if(!dict.ContainsKey(item)) dict[item] = 0; dict[item]++; } string label = string.Empty; int count = -1; foreach(var key in dict.Keys) { if(dict[key] > count) { label = key; count = dict[key]; } } return label; } } }
需要注意的是,計(jì)算距離時(shí),數(shù)量級大的維度會(huì)對距離影響大,因此大多數(shù)情況下,不能直接計(jì)算,要對原始數(shù)據(jù)做歸一化,并根據(jù)重要性進(jìn)行加權(quán)。歸一化可以使用公式:value = (old-min)/(max-min),其中old是原始值,max是所有數(shù)據(jù)的最大值,min是所有數(shù)據(jù)的最小值。這樣計(jì)算得到的value將落在0至1的區(qū)間上。
這個(gè)算法太簡單,暫時(shí)不上測試代碼了,有時(shí)間再補(bǔ)吧。
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。