溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

機(jī)器學(xué)習(xí)算法:k近鄰

發(fā)布時(shí)間:2020-06-17 05:28:07 來源:網(wǎng)絡(luò) 閱讀:1425 作者:BoyTNT 欄目:編程語言

前言:

最近在研究機(jī)器學(xué)習(xí),過程中的心得體會(huì)會(huì)記錄到blog里,文章與代碼均為原創(chuàng)。會(huì)不定期龜速更新,注意這不是正式的教程,因?yàn)楸救艘彩浅鯇W(xué)者,但是估計(jì)C#版本的代碼能幫到一些剛?cè)腴T的同學(xué)去理解復(fù)雜的公式。


------------------------ 我是分割線 ------------------------


k近鄰(k-Nearest Neighbor,KNN)算法,應(yīng)該是機(jī)器學(xué)習(xí)里最基礎(chǔ)的算法,其核心思想是:給定一個(gè)未知分類的樣本,如果與它最相似的k個(gè)已知樣本中的多數(shù)屬于某一個(gè)分類,那么這個(gè)未知樣本也屬于這個(gè)分類。


所謂相似,是指兩個(gè)樣本之間的歐氏距離小,其計(jì)算公式為:

機(jī)器學(xué)習(xí)算法:k近鄰

其中Xi為樣本X的第i個(gè)特征。


k近鄰算法的優(yōu)點(diǎn)在于實(shí)現(xiàn)簡單,缺點(diǎn)在于時(shí)間和空間復(fù)雜度高。


上C#版代碼,這里取k=1,即只根據(jù)最相近的一個(gè)點(diǎn)確定分類:

首先是DataVector,包含N維數(shù)據(jù)和分類標(biāo)簽,用于表示一個(gè)樣本。

using System;

namespace MachineLearning
{
    /// <summary>
    /// 數(shù)據(jù)向量
    /// </summary>
    /// <typeparam name="T"></typeparam>
    public class DataVector<T>
    {
        /// <summary>
        /// N維數(shù)據(jù)
        /// </summary>
        public T[] Data { get; private set; }
        /// <summary>
        /// 分類標(biāo)簽
        /// </summary>
        public string Label { get; set; }

        /// <summary>
        /// 構(gòu)造
        /// </summary>
        /// <param name="dimension">數(shù)據(jù)維度</param>
        public DataVector(int dimension)
        {
            Data = new T[dimension];
        }
        
        public int Dimension
        {
            get { return this.Data.Length; }
        }
    }
}


然后是核心算法:

using System;
using System.Collections.Generic;

namespace MachineLearning
{
    /// <summary>
    /// k近鄰法
    /// </summary>
    public class NearestNeighbour
    {
        private int m_K;
        private List<DataVector<double>> m_TrainingSet;
        
        public NearestNeighbour(int k = 1)
        {
            m_K = k;
        }
        
        /// <summary>
        /// 訓(xùn)練
        /// </summary>
        /// <param name="trainingSet"></param>
        public void Train(List<DataVector<double>> trainingSet)
        {
            m_TrainingSet = trainingSet;
        }

        /// <summary>
        /// 分類
        /// </summary>
        /// <param name="vector"></param>
        /// <returns></returns>
        public string Classify(DataVector<double> vector)
        {
            //K=1時(shí)可簡化處理提高效率
            if(m_K == 1)
            {
                double minDist = double.PositiveInfinity;
                int targetIndex = -1;
                for(int i = 0;i < m_TrainingSet.Count;i++)
                {
                    //計(jì)算距離
                    double distance = ComputeDistance(vector, m_TrainingSet[i], minDist);

                    //找最小值
                    if(distance < minDist)
                    {
                        minDist = distance;
                        targetIndex = i;
                    }
                }
            
                return m_TrainingSet[targetIndex].Label;
            }
            else
            {
                var dict = new SortedDictionary<double, string>();
                
                for(int i = 0;i < m_TrainingSet.Count;i++)
                {
                    //計(jì)算距離并記錄
                    double distance = ComputeDistance(vector, m_TrainingSet[i]);
                    dict[distance] = m_TrainingSet[i].Label;
                }
                
                //找最多的Label
                var labels = new List<string>();
                int count = 0;
                foreach(var label in dict.Values)
                {
                    labels.Add(label);
                    if(++count > m_K - 1)
                        break;
                }

                return GetMajorLabel(labels);
            }
        }
    
        /// <summary>
        /// 計(jì)算距離
        /// </summary>
        /// <param name="v1"></param>
        /// <param name="v2"></param>
        /// <param name="minValue"></param>
        /// <returns></returns>
        private double ComputeDistance(DataVector<double> v1, DataVector<double> v2, double minValue = double.PositiveInfinity)
        {
            double distance = 0.0;
            minValue = minValue * minValue;
            for(int i = 0;i < v1.Data.Length;++i)
            {
                double diff = v1.Data[i] - v2.Data[i];
                distance += diff * diff;
            
                //如果當(dāng)前累加的距離已經(jīng)大于給定的最小值,不用繼續(xù)計(jì)算了
                if(distance > minValue)
                    return double.PositiveInfinity;
            }

            return Math.Sqrt(distance);
        }
        
        /// <summary>
        /// 取多數(shù)
        /// </summary>
        /// <param name="dataSet"></param>
        /// <returns></returns>
        private string GetMajorLabel(List<string> labels)
        {
            var dict = new Dictionary<string, int>();
            foreach(var item in labels)
            {
                if(!dict.ContainsKey(item))
                    dict[item] = 0;
                dict[item]++;
            }

            string label = string.Empty;
            int count = -1;
            foreach(var key in dict.Keys)
            {
                if(dict[key] > count)
                {
                    label = key;
                    count = dict[key];
                }
            }
            
            return label;
        }
    }
}


需要注意的是,計(jì)算距離時(shí),數(shù)量級大的維度會(huì)對距離影響大,因此大多數(shù)情況下,不能直接計(jì)算,要對原始數(shù)據(jù)做歸一化,并根據(jù)重要性進(jìn)行加權(quán)。歸一化可以使用公式:value = (old-min)/(max-min),其中old是原始值,max是所有數(shù)據(jù)的最大值,min是所有數(shù)據(jù)的最小值。這樣計(jì)算得到的value將落在0至1的區(qū)間上。


這個(gè)算法太簡單,暫時(shí)不上測試代碼了,有時(shí)間再補(bǔ)吧。


向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI