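The core idea of the KNN (k-nearest neighbors) algorithm is that if most of the K samples nearest to a given sample in feature space belong to one class, then the sample is assigned to that class too and is assumed to share the properties of that class. The classification decision therefore depends only on the classes of the one or few nearest samples, that is, on a very small number of neighbors. Because KNN relies on a limited set of nearby samples rather than on discriminating class regions, it tends to be better suited than other methods to sample sets whose class regions cross or overlap heavily.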
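The KNN algorithm proceeds as follows:
1. Compute the distance between the current test sample and every record in the training data.
2. Select the K training records with the smallest distances as the nearest neighbors of the test sample.
3. Find the class that occurs most often among these K training records and assign it to the current test sample.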
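The three steps above can be tried on a single machine before moving to MapReduce. The following is a minimal sketch, not the article's implementation: the class `KnnSketch`, its `Sample` type, and the toy data in `main` are hypothetical names introduced only for illustration, and ties in the vote are resolved in favor of the label that reaches the highest count first.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical single-machine KNN classifier used only to illustrate the three steps. */
final class KnnSketch {

    /** A training record: a feature vector plus its class label. */
    static final class Sample {
        final double[] features;
        final String label;

        Sample(double[] features, String label) {
            this.features = features;
            this.label = label;
        }
    }

    /** Euclidean distance between two vectors of equal length. */
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Returns the majority label among the k training samples closest to the query. */
    static String classify(final double[] query, List<Sample> training, int k) {
        // Steps 1 and 2: order a copy of the training data by distance to the query.
        List<Sample> sorted = new ArrayList<>(training);
        sorted.sort(Comparator.comparingDouble((Sample s) -> distance(query, s.features)));

        // Step 3: count labels among the first k samples and keep the most frequent one.
        int limit = Math.min(k, sorted.size());
        Map<String, Integer> votes = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (int i = 0; i < limit; i++) {
            String label = sorted.get(i).label;
            int count = votes.merge(label, 1, Integer::sum);
            if (count > bestCount) {
                bestCount = count;
                best = label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<Sample> training = new ArrayList<>();
        training.add(new Sample(new double[] {1, 1}, "A"));
        training.add(new Sample(new double[] {1, 2}, "A"));
        training.add(new Sample(new double[] {2, 1}, "A"));
        training.add(new Sample(new double[] {8, 8}, "B"));
        training.add(new Sample(new double[] {9, 8}, "B"));
        training.add(new Sample(new double[] {8, 9}, "B"));
        System.out.println(classify(new double[] {2, 2}, training, 3));
    }
}
```

Sorting the whole training set costs O(n log n) per query, which is acceptable for a toy example but is one reason a direct translation of these steps to MapReduce is not efficient.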
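The steps above are only the general flow of KNN. A MapReduce (MR) implementation that follows them directly is not very efficient and leaves room for optimization. This article covers only the straightforward MR implementation that follows this flow.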
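Mapper design:
Because the test data is much smaller than the training data, it is read through the Java API and held in memory, so the test data must be initialized in setup. In map, the distance between each test sample and the current training record is computed. The Mapper's type parameters are <Object, Text, IntWritable, MyWritable>. The map output key is an IntWritable holding the index of the test sample. The output value is a MyWritable, a custom value type that holds the distance and the class of the training record the test sample was compared against.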
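import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KnnMapper extends Mapper<Object, Text, IntWritable, MyWritable> {
    private static final Logger log = LoggerFactory.getLogger(KnnMapper.class);

    // Test samples, loaded once per map task in setup() and kept in memory.
    private List<float[]> testData;

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        // NameNode address used in the original setup; adjust for your cluster.
        conf.set("fs.defaultFS", "hdfs://master:8020");
        // "TestFilePath" must be set on the job configuration by the driver.
        Path testDataPath = new Path(conf.get("TestFilePath"));
        FileSystem fs = FileSystem.get(conf);
        this.testData = readTestData(fs, testDataPath);
    }

    @Override
    protected void map(Object key, Text value, Context context)
            throws IOException, InterruptedException {
        // A training record is "f1,f2,...,fn,label": features first, class last.
        String[] line = value.toString().split(",");
        String label = line[line.length - 1];
        float[] trainData = new float[line.length - 1];
        for (int i = 0; i < trainData.length; i++) {
            trainData[i] = Float.parseFloat(line[i]);
        }
        log.debug("training record: {} class: {}", value, label);

        // Emit one (test index, (distance, class)) pair per test sample.
        for (int i = 0; i < this.testData.size(); i++) {
            float distance = euclideanDistance(this.testData.get(i), trainData);
            context.write(new IntWritable(i), new MyWritable(distance, label));
        }
    }

    // Reads the test file: one comma-separated feature vector per line.
    private List<float[]> readTestData(FileSystem fs, Path path) throws IOException {
        List<float[]> list = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(fs.open(path), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // skip blank lines
                }
                String[] items = line.split(",");
                float[] item = new float[items.length];
                for (int i = 0; i < items.length; i++) {
                    item[i] = Float.parseFloat(items[i]);
                }
                list.add(item);
            }
        }
        return list;
    }

    // Euclidean distance; both vectors must have the same number of features.
    private static float euclideanDistance(float[] testData, float[] inData) {
        float distance = 0.0f;
        for (int i = 0; i < testData.length; i++) {
            distance += (testData[i] - inData[i]) * (testData[i] - inData[i]);
        }
        return (float) Math.sqrt(distance);
    }
}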
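To make the mapper's output concrete, the sketch below reproduces what a single map call emits, without any Hadoop dependency. The class `MapStepSketch`, its `emit` method, and the tab-separated string it builds are assumptions made for illustration; the real mapper writes IntWritable and MyWritable objects through the context instead of strings.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical Hadoop-free model of one map call: one training record in, one pair per test sample out. */
final class MapStepSketch {

    /** Euclidean distance between two float vectors of equal length. */
    static float euclidean(float[] a, float[] b) {
        float sum = 0.0f;
        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return (float) Math.sqrt(sum);
    }

    /**
     * Parses a training line of the form "f1,...,fn,label" and returns one
     * "testIndex<TAB>distance,label" string per in-memory test sample.
     */
    static List<String> emit(String trainingLine, List<float[]> testData) {
        String[] fields = trainingLine.split(",");
        String label = fields[fields.length - 1];
        float[] train = new float[fields.length - 1];
        for (int i = 0; i < train.length; i++) {
            train[i] = Float.parseFloat(fields[i].trim());
        }
        List<String> out = new ArrayList<>();
        for (int i = 0; i < testData.size(); i++) {
            float d = euclidean(testData.get(i), train);
            out.add(i + "\t" + d + "," + label);
        }
        return out;
    }

    public static void main(String[] args) {
        List<float[]> tests = Arrays.asList(new float[] {0f, 0f}, new float[] {3f, 4f});
        // Prints "0<TAB>5.0,A" and then "1<TAB>0.0,A".
        for (String pair : emit("3,4,A", tests)) {
            System.out.println(pair);
        }
    }
}
```

Every training record produces as many pairs as there are test samples, so the shuffle carries (training records × test samples) values. This is where the direct approach becomes expensive.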
The custom value type MyWritable is defined as follows:
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.Writable;

public class MyWritable implements Writable {
    private float distance;
    private String label;

    // Hadoop instantiates Writables reflectively, so a no-argument constructor is required.
    public MyWritable() {
    }

    public MyWritable(float distance, String label) {
        this.distance = distance;
        this.label = label;
    }

    @Override
    public String toString() {
        return this.distance + "," + this.label;
    }

    // write and readFields must handle the fields in the same order.
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeFloat(distance);
        out.writeUTF(label);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        this.distance = in.readFloat();
        this.label = in.readUTF();
    }

    public float getDistance() {
        return distance;
    }

    public void setDistance(float distance) {
        this.distance = distance;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }
}
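A Writable is rebuilt on the receiving side by calling readFields on the bytes that write produced, so the two methods must process the fields in the same order. The sketch below checks that round trip using only java.io. `WritableRoundTrip` and `DistanceLabel` are hypothetical stand-ins that mirror MyWritable's two fields without implementing Hadoop's Writable interface, so the example runs without Hadoop on the classpath.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical round-trip check for a (distance, label) value serialized like MyWritable. */
final class WritableRoundTrip {

    /** Stand-in for MyWritable: same fields, same write/readFields order. */
    static final class DistanceLabel {
        float distance;
        String label = "";

        void write(DataOutput out) throws IOException {
            out.writeFloat(distance);
            out.writeUTF(label);
        }

        void readFields(DataInput in) throws IOException {
            distance = in.readFloat();
            label = in.readUTF();
        }
    }

    /** Serializes the two fields into a byte array. */
    static byte[] serialize(float distance, String label) {
        DistanceLabel value = new DistanceLabel();
        value.distance = distance;
        value.label = label;
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            value.write(out);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Serializes and then deserializes into a fresh object. */
    static DistanceLabel roundTrip(float distance, String label) {
        DistanceLabel copy = new DistanceLabel();
        byte[] data = serialize(distance, label);
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            copy.readFields(in);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return copy;
    }

    public static void main(String[] args) {
        DistanceLabel copy = roundTrip(2.5f, "label2");
        System.out.println(copy.distance + "," + copy.label);
    }
}
```

For an ASCII label such as "label2", the serialized form is 12 bytes: 4 for the float, 2 for the length prefix written by writeUTF, and 6 for the characters.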
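On the Reducer side, the parameter K must be initialized; it is the number of nearest objects to select. In reduce, the K records with the smallest distances are selected (conceptually, the distances are sorted in ascending order and the first K are taken), and the class that occurs most often among those K records is computed. That class is paired with the index of the test sample and written to HDFS as a key-value pair.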
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class KnnReducer extends Reducer<IntWritable, MyWritable, IntWritable, Text> {
    private int K;

    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        // Number of neighbors; defaults to 5 if the driver does not set "K".
        this.K = context.getConfiguration().getInt("K", 5);
    }

    /**
     * key    => 0
     * values => ([1,label1],[2,label2],[3,label2],[2.5,label2])
     */
    @Override
    protected void reduce(IntWritable key, Iterable<MyWritable> values, Context context)
            throws IOException, InterruptedException {
        // K candidate slots, initialized as empty placeholders.
        MyWritable[] nearest = new MyWritable[K];
        for (int i = 0; i < K; i++) {
            nearest[i] = new MyWritable(Float.MAX_VALUE, "-1");
        }

        // Keep the K smallest distances. Each incoming value replaces only the
        // current farthest candidate, and only if it is closer. (Overwriting every
        // slot that is farther would fill all K slots with the same neighbor.)
        for (MyWritable m : values) {
            int farthest = 0;
            for (int i = 1; i < K; i++) {
                if (nearest[i].getDistance() > nearest[farthest].getDistance()) {
                    farthest = i;
                }
            }
            if (m.getDistance() < nearest[farthest].getDistance()) {
                nearest[farthest].setDistance(m.getDistance());
                nearest[farthest].setLabel(m.getLabel());
            }
        }

        // Collect the labels of the filled slots; placeholders are skipped so that
        // "-1" cannot win the vote when fewer than K values arrive.
        List<String> labels = new ArrayList<>();
        for (MyWritable m : nearest) {
            if (m.getDistance() < Float.MAX_VALUE) {
                labels.add(m.getLabel());
            }
        }
        String countMost = mostEle(labels.toArray(new String[0]));
        context.write(key, new Text(countMost));
    }

    // Returns the most frequent string in the array, i.e. the majority class.
    public static String mostEle(String[] strArray) {
        Map<String, Integer> counts = new HashMap<>();
        for (String str : strArray) {
            Integer current = counts.get(str);
            counts.put(str, current == null ? 1 : current + 1);
        }
        String maxString = "";
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > maxCount) {
                maxCount = entry.getValue();
                maxString = entry.getKey();
            }
        }
        return maxString;
    }
}
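As an alternative to keeping K candidates in an array, the selection step can also be written with a bounded max-heap, which keeps the farthest current candidate at the head so it can be evicted cheaply. This is a swapped-in technique, not the article's method. The sketch is Hadoop-free, and `TopKVoteSketch` and `Neighbor` are hypothetical names. It assumes k is at least 1.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

/** Hypothetical top-K selection plus majority vote, using a bounded max-heap. */
final class TopKVoteSketch {

    /** Immutable (distance, label) pair, standing in for one reducer input value. */
    static final class Neighbor {
        final float distance;
        final String label;

        Neighbor(float distance, String label) {
            this.distance = distance;
            this.label = label;
        }
    }

    /** Returns the majority label among the k neighbors with the smallest distance. */
    static String vote(Iterable<Neighbor> values, int k) {
        // Max-heap ordered by distance: the head is the farthest retained candidate.
        PriorityQueue<Neighbor> heap = new PriorityQueue<Neighbor>(
                k, (Neighbor a, Neighbor b) -> Float.compare(b.distance, a.distance));
        for (Neighbor n : values) {
            if (heap.size() < k) {
                heap.add(n);
            } else if (n.distance < heap.peek().distance) {
                heap.poll(); // evict the farthest candidate
                heap.add(n);
            }
        }

        // Majority vote over the retained neighbors.
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (Neighbor n : heap) {
            int count = counts.merge(n.label, 1, Integer::sum);
            if (count > bestCount) {
                bestCount = count;
                best = n.label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Same values as the example in the reducer's comment.
        Iterable<Neighbor> values = Arrays.asList(
                new Neighbor(1f, "label1"), new Neighbor(2f, "label2"),
                new Neighbor(3f, "label2"), new Neighbor(2.5f, "label2"));
        System.out.println(vote(values, 3));
    }
}
```

With the four example values and k = 3, the retained distances are 1, 2 and 2.5, so the vote returns label2. Inside a real reducer, Hadoop reuses the value object while iterating, so each value would need to be copied before being stored in the heap.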
In the final output, each record pairs the index of a test sample with its predicted class.