溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

怎么使用棧的記憶化搜索來加速子集和算法

發(fā)布時(shí)間:2021-10-23 11:39:11 來源:億速云 閱讀:134 作者:iii 欄目:編程語言

本篇內(nèi)容主要講解“怎么使用棧的記憶化搜索來加速子集和算法”,感興趣的朋友不妨來看看。本文介紹的方法操作簡(jiǎn)單快捷,實(shí)用性強(qiáng)。下面就讓小編來帶大家學(xué)習(xí)“怎么使用棧的記憶化搜索來加速子集和算法”吧!

所謂子集和就是在一個(gè)數(shù)組中找出它的子集,使得該子集的和等于某個(gè)固定值。

一般我們都是使用遞歸加回溯的方式來處理的,代碼如下(此處我們只找出一組滿足的條件即可)

public class SubSet {private List<Integer> list = new ArrayList<>();   //用于存放求取子集中的元素    @Getter    private List<Integer> res = new ArrayList<>();    //求取數(shù)組列表中元素和    public int getSum(List<Integer> list) {int sum = 0;        for(int i = 0;i < list.size();i++)
            sum += list.get(i);        return sum;    }public void getSubSet(int[] A, int m, int step) {if (res.size() > 0) {return;        }while(step < A.length) {list.add(A[step]);            if (getSum(list) == m) {if (getSum(res) == 0) {res.addAll(list);                }
            }
            step++;            getSubSet(A, m, step);            list.remove(list.size() - 1);   //回溯執(zhí)行語句,刪除列表最后一個(gè)元素        }
    }public static void main(String[] args) {
        SubSet test = new SubSet();        int[] A = new int[6];        for(int i = 0;i < 6;i++) {
            A[i] = i + 1;        }
        test.getSubSet(A, 8, 0);        System.out.println(test.getRes());    }
}

運(yùn)行結(jié)果

[1, 2, 5]

但是這個(gè)算法的時(shí)間復(fù)雜度非常高,是NP級(jí)別的。如果數(shù)據(jù)量比較大的時(shí)候,將很難完成運(yùn)算。

現(xiàn)在我們用棧和哈希緩存來加速這個(gè)算法。主要是緩存計(jì)算結(jié)果,不用每次都去getSum中把list的和算一遍。其思想主要是記憶化搜索,可以參考本人這篇博客動(dòng)態(tài)規(guī)劃、回溯、貪心,分治

public class SubSet {private List<Integer> list = new ArrayList<>();   //用于存放求取子集中的元素    @Getter    private List<Integer> res = new ArrayList<>();    private Deque<Integer> deque = new ArrayDeque<>();    private Map<String,Integer> map = new HashMap<>();    //求取數(shù)組列表中元素和    public int getSum(List<Integer> list) {int sum = 0;        for(int i = 0;i < list.size();i++)
            sum += list.get(i);        return sum;    }public void getSubSet(int[] A, int m, int step) {if (res.size() > 0) {return;        }while(step < A.length) {list.add(A[step]);            if (!map.containsKey(deque.toString())) {int sum = getSum(list);                deque.push(A[step]);                map.put(deque.toString(),sum);                if (sum == m) {if (getSum(res) == 0) {res.addAll(list);                    }
                }
            }else {int sum = map.get(deque.toString()) + A[step];                deque.push(A[step]);                map.put(deque.toString(),sum);                if (sum == m) {if (getSum(res) == 0) {res.addAll(list);                    }
                }
            }
            step++;            getSubSet(A, m, step);            list.remove(list.size() - 1);   //回溯執(zhí)行語句,刪除列表最后一個(gè)元素            deque.pop();        }
    }public static void main(String[] args) {
        SubSet test = new SubSet();        int[] A = new int[6];        for(int i = 0;i < 6;i++) {
            A[i] = i + 1;        }
        test.getSubSet(A, 8, 0);        System.out.println(test.getRes());    }
}

運(yùn)算結(jié)果

[1, 2, 5]

但C#無法滿足獲取棧的值,只能獲取棧的類型,如果我們用遍歷的方式去獲取棧的值又回到了以前NP級(jí)的時(shí)間復(fù)雜度,故直接使用數(shù)字來做哈希表的鍵。內(nèi)容如下

using System;
using System.Collections.Generic;
using System.Collections;
using System.Text.RegularExpressions;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ConsoleApplication1
{
    class Program
    {
        private class Oranize
        {
            public List<decimal> array = new List<decimal>();
            public List<decimal> res = new List<decimal>();
            public Stack<decimal> stack = new Stack<decimal>();
            public Hashtable table = new Hashtable();
            public decimal index = 0;

            public decimal getSum(List<decimal> list)
            {
                decimal sum = 0;
                for (int i = 0; i < list.Count; i++)
                {
                    sum += list[i];
                }
                return sum;
            }

            public String stackValue(Stack<decimal> stack)
            {
                StringBuilder sb = new StringBuilder();
                foreach (decimal s in stack)
                {
                    sb.Append(s.ToString());
                }
                return sb.ToString();
            }

            public void org(decimal[] arr,decimal all, int step)
            {
                if (res.Count > 0)
                {
                    return;
                }
                while (step < arr.Length)
                {
                    array.Add(arr[step]);                    
                    if (!table.ContainsKey(index.ToString()))
                    {
                        decimal sum = getSum(array);
                        stack.Push(index);
                        table.Add(stack.Peek().ToString(), sum);
                        if (sum == all)
                        {
                            if (getSum(res) == 0)
                            {
                                foreach (decimal a in array)
                                {
                                    res.Add(a);
                                }
                            }
                        }
                    }
                    else
                    {
                        decimal sum = 0;
                        if (stack.Count > 0)
                        {
                            sum = Convert.ToDecimal(table[stack.Peek().ToString()]) + arr[step];
                        }
                        else
                        {
                            sum = Convert.ToDecimal(table["0"]) + arr[step];
                        }
                        index++;
                        stack.Push(index);
                        if (table.ContainsKey(stack.Peek().ToString()))
                        {
                            table.Remove(stack.Peek().ToString());
                        }
                        table.Add(stack.Peek().ToString(), sum);
                        if (sum == all)
                        {
                            if (getSum(res) == 0)
                            {
                                foreach (decimal a in array)
                                {
                                    res.Add(a);
                                }
                            }
                        }
                    }
                    step++;
                    org(arr, all, step);
                    array.RemoveAt(array.Count - 1);
                    stack.Pop();
                }
            }
        }
        static void Main(string[] args)
        {
            decimal[] A = new decimal[6];
            for (int i = 0; i < 6; i++)
            {
                A[i] = i + 1;
            }
            Oranize oranize = new Oranize();
            oranize.org(A, 8, 0);

            foreach (decimal r in oranize.res)
            {
                Console.Write(r + ",");
            }
            Console.ReadLine();
        }
    }
}

這里我們可以看到如果使用stackValue來獲取棧的各個(gè)值的字符串是不可取的,同樣會(huì)非常慢。

由于C#本身的Hashtable在數(shù)據(jù)量大的情況下存在溢出風(fēng)險(xiǎn),所以我們要重寫哈希表。重寫的哈希表的每個(gè)節(jié)點(diǎn)由紅黑樹組成,由于我們并不需要?jiǎng)h除哈希表內(nèi)的元素,所以就不寫紅黑樹和哈希表的刪除方法。

        private class RedBlackTreeMap
        {
            private static bool RED = true;
            private static bool BLACK = false;

            private class Node
            {
                public String key;
                public decimal value;
                public Node left;
                public Node right;
                public bool color;

                public Node(String key,decimal value,Node left,Node right,bool color)
                {
                    this.key = key;
                    this.value = value;
                    this.left = left;
                    this.right = right;
                    this.color = color;
                }

                public Node(String key): this(key, 0, null, null, RED)
                { }

                public Node(String key,decimal value): this(key, value, null, null, RED)
                { }                                    
            }

            private Node root;
            private int size;
            public ISet<String> keySet = new HashSet<String>();

            public RedBlackTreeMap()
            {
                root = null;
                size = 0;
            }

            private bool isRed(Node node)
            {
                if (node == null)
                {
                    return BLACK;
                }
                return node.color;
            }

            private Node leftRotate(Node node)
            {
                Node ret = node.right;
                Node retLeft = ret.left;
                node.right = retLeft;
                ret.left = node;
                ret.color = node.color;
                node.color = RED;
                return ret;
            }

            private Node rightRotate(Node node)
            {
                Node ret = node.left;
                Node retRight = ret.right;
                node.left = retRight;
                ret.right = node;
                ret.color = node.color;
                node.color = RED;
                return ret;
            }

            private void flipColors(Node node)
            {
                node.color = RED;
                node.left.color = BLACK;
                node.right.color = BLACK;
            }

            public void add(String key,decimal value)
            {
                root = add(root, key, value);
                keySet.Add(key);
            }

            private Node add(Node node,String key,decimal value)
            {
                if (node == null)
                {
                    size++;
                    return new Node(key, value);
                }
                if (key.CompareTo(node.key) < 0)
                {
                    node.left = add(node.left, key, value);
                }else if (key.CompareTo(node.key) > 0)
                {
                    node.right = add(node.right, key, value);
                }else
                {
                    node.value = value;
                }
                if (isRed(node.right) && !isRed(node.left))
                {
                    node = leftRotate(node);
                }
                if (isRed(node.left) && isRed(node.left.left))
                {
                    node = rightRotate(node);
                }
                if (isRed(node.left) && isRed(node.right))
                {
                    flipColors(node);
                }
                return node;
            }

            public bool contains(String key)
            {
                return getNode(root, key) != null;
            }

            public decimal get(String key)
            {
                Node node = getNode(root, key);
                return node == null ? 0 : node.value;
            }

            public void set(String key,decimal value)
            {
                Node node = getNode(root, key);
                if (node == null)
                {
                    throw new ArgumentException(key + "不存在");
                }
                node.value = value;
            }

            public int getSize()
            {
                return size;
            }

            public bool isEmpty()
            {
                return size == 0;
            }

            private Node getNode(Node node,String key)
            {
                if (node == null)
                {
                    return null;
                }
                if (key.CompareTo(node.key) == 0)
                {
                    return node;
                }else if (key.CompareTo(node.key) < 0)
                {
                    return getNode(node.left, key);
                }else
                {
                    return getNode(node.right, key);
                }
            }
        }

        private class HashFind
        {
            private int[] capacity = {53,97,193,389,769,1543,3079,6151,12289,24593,
            49157,98317,196613,393241,786433,1572869,3145739,
            6291469,12582917,25165843,50331653,100663319,
            201326611,402653189,805306457,1610612741};

            //容忍度上界
            private static int upperTol = 10;
            //容忍度下屆
            private static int lowerTol = 2;
            private int capacityIndex = 0;

            private RedBlackTreeMap[] tables;
            private int M;
            private int size;

            public HashFind()
            {
                this.M = capacity[capacityIndex];
                this.size = 0;
                tables = new RedBlackTreeMap[M];
                for (int i = 0; i < M; i++)
                {
                    tables[i] = new RedBlackTreeMap();
                }
            }

            private int hash(String key)
            {
                return (key.GetHashCode() & 0x7fffffff) % M;
            }

            public void add(String key,decimal value)
            {
                RedBlackTreeMap map = tables[hash(key)];
                if (map.contains(key))
                {
                    map.add(key, value);
                }else
                {
                    map.add(key, value);
                    size++;
                    if (size >= upperTol * M && capacityIndex + 1 < capacity.Length)
                    {
                        capacityIndex++;
                        resize(capacity[capacityIndex]);
                    }
                }
            }

            public bool contains(String key)
            {
                int index = hash(key);
                return tables[index].contains(key);
            }

            public decimal get(String key)
            {
                int index = hash(key);
                return tables[index].get(key);
            }

            public void set(String key,decimal value)
            {
                int index = hash(key);
                RedBlackTreeMap map = tables[index];
                if(!map.contains(key))
                {
                    throw new ArgumentException(key + "不存在");
                }
                map.add(key, value);
            }

            public int getSize()
            {
                return size;
            }

            public bool isEmpty()
            {
                return size == 0;
            }

            private void resize(int newM)
            {
                RedBlackTreeMap[] newTables = new RedBlackTreeMap[newM];
                for (int i = 0; i < newM; i++)
                {
                    newTables[i] = new RedBlackTreeMap();
                }
                int oldM = this.M;
                this.M = newM;
                for (int i = 0; i < oldM; i++)
                {
                    RedBlackTreeMap map = tables[i];
                    foreach (String key in map.keySet)
                    {
                        int index = hash(key);
                        newTables[index].add(key, map.get(key));
                    }
                }
                this.tables = newTables;
            }
        }

        private class Oranize
        {
            public List<decimal> array = new List<decimal>();
            public List<decimal> res = new List<decimal>();
            public Stack<decimal> stack = new Stack<decimal>();
            public HashFind table = new HashFind();
            public decimal index = 0;

            public decimal getSum(List<decimal> list)
            {
                decimal sum = 0;
                for (int i = 0; i < list.Count; i++)
                {
                    sum += list[i];
                }
                return sum;
            }

            //public String stackValue(Stack<decimal> stack)
            //{
            //    StringBuilder sb = new StringBuilder();
            //    foreach (decimal s in stack)
            //    {
            //        sb.Append(s.ToString());
            //    }
            //    return sb.ToString();
            //}

            public void org(decimal[] arr, decimal all, int step)
            {
                if (res.Count > 0)
                {
                    return;
                }
                while (step < arr.Length)
                {
                    array.Add(arr[step]);
                    if (!table.contains(index.ToString()))
                    {
                        decimal sum = getSum(array);
                        stack.Push(index);
                        table.add(stack.Peek().ToString(), sum);
                        if (sum == all)
                        {
                            if (getSum(res) == 0)
                            {
                                foreach (decimal a in array)
                                {
                                    res.Add(a);
                                }
                            }
                        }
                    }
                    else
                    {
                        decimal sum = 0;
                        if (stack.Count > 0)
                        {
                            sum = Convert.ToDecimal(table.get(stack.Peek().ToString())) + arr[step];
                        }
                        else
                        {
                            sum = Convert.ToDecimal(table.get("0")) + arr[step];
                        }
                        index++;
                        stack.Push(index);
                        if (!table.contains(stack.Peek().ToString()))
                        {
                            table.add(stack.Peek().ToString(), sum);
                        }
                        if (sum == all)
                        {
                            if (getSum(res) == 0)
                            {
                                foreach (decimal a in array)
                                {
                                    res.Add(a);
                                }
                            }
                        }
                    }
                    step++;
                    org(arr, all, step);
                    array.RemoveAt(array.Count - 1);
                    stack.Pop();
                }
            }
        }

雖然該算法進(jìn)行了加速,但是能否算出,依然在于數(shù)組元素的個(gè)數(shù)所組成的和的組合數(shù),比如有1、2、3、4四個(gè)數(shù),則這四個(gè)數(shù)的和的組合數(shù)為1、2、3、4、1+2、1+2+3、1+2+4、1+2+3+4、1+3、1+3+4、1+4、2+3、2+3+4、2+4、3+4總共15個(gè)。

我們可以用計(jì)算組合數(shù)算法來進(jìn)行驗(yàn)證,該算法也是使用遞歸加記憶化搜索的方式

public class Combine {private static Map<String,Long> map= new HashMap<>();    /**     * 計(jì)算從m個(gè)元素中拿出n個(gè)元素的組合數(shù)     * @param m     * @param n     * @return     */    private static long comb(int m,int n){
        String key= m+","+n;        if(n == 0)return 1;        if (n == 1)return m;        if(n > m / 2)return comb(m,m-n);        if(n > 1){if(!map.containsKey(key))map.put(key, comb(m-1,n-1)+comb(m-1,n));            return map.get(key);        }return -1;    }public static void main(String[] args) {long total = 0;        for (int i = 1 ; i <= 4; i++) {
            total += comb(4,i);        }
        System.out.println(total);    }
}

運(yùn)行結(jié)果

15

我們現(xiàn)在的主要目的是尋找可計(jì)算的節(jié)點(diǎn),我們可以先給出一個(gè)比較大的數(shù),比如一個(gè)數(shù)組中有40個(gè)元素

public static void main(String[] args) {long total = 0;    for (int i = 1 ; i <= 40; i++) {
        total += comb(40,i);    }
    System.out.println(total);}

運(yùn)行結(jié)果

1099511627775

由結(jié)果可知,40個(gè)數(shù)的組合數(shù)達(dá)到了萬億級(jí)別,一般我們計(jì)算機(jī)的計(jì)算級(jí)數(shù)量在億級(jí)別就差不多了,再多的話就比較難算的出來了。當(dāng)然這里我的個(gè)人建議是數(shù)組元素?cái)?shù)量在28個(gè)

public static void main(String[] args) {long total = 0;    for (int i = 1 ; i <= 28; i++) {
        total += comb(28,i);    }
    System.out.println(total);}

運(yùn)行結(jié)果

268435455

這里是2.6億,最后我們來看一下30的組合數(shù)

public static void main(String[] args) {long total = 0;    for (int i = 1 ; i <= 30; i++) {
        total += comb(30,i);    }
    System.out.println(total);}

運(yùn)行結(jié)果

1073741823

運(yùn)行結(jié)果為10億,所以我們可以看出從28到30,增長(zhǎng)的組合數(shù)絕對(duì)不是一點(diǎn)點(diǎn)。這是一個(gè)幾何級(jí)數(shù)的增長(zhǎng)。

到此,相信大家對(duì)“怎么使用棧的記憶化搜索來加速子集和算法”有了更深的了解,不妨來實(shí)際操作一番吧!這里是億速云網(wǎng)站,更多相關(guān)內(nèi)容可以進(jìn)入相關(guān)頻道進(jìn)行查詢,關(guān)注我們,繼續(xù)學(xué)習(xí)!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI