首页 | 联系我们 | 叶凡网络官方QQ群:323842844
游客,欢迎您! 请登录 免费注册 忘记密码
您所在的位置:首页 > 新闻中心 > 行业新闻 > 正文

名扬互联:Java实现双数组Trie树(DoubleArrayTrie,DAT)

作者:cocomyyz 来源: 日期:2013-12-5 10:36:08 人气:0 加入收藏 评论:0 标签:

传统的Trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,爆炸起来的空间根本无法接受。

双数组Trie就是优化了空间的Trie树,原理本文就不讲了,请参考An Efficient Implementation of Trie Structures,本程序的编写也是参考这篇论文的。

关于几点论文没有提及的细节和与论文不一一致的实现:

1.对于插入字符串,如果有一个字符串是另一个字符串的子串的话,我是将结束符也作为一条边,产生一个新的结点,这个结点新节点的Base我置为0

所以一个字符串结束也有2中情况:一个是Base值为负,存储剩余字符(可能只有一个结束符)到Tail数组;另一个是Base为0。

所以在查询的时候要考虑一下这两种情况

2.对于第一种冲突(论文中的Case 3),可能要将Tail中的字符串取出一部分,作为边放到索引中。论文是使用将尾串左移的方式,我的方式直接修改Base值,而不是移动尾串。

下面是java实现的代码,可以处理相同字符串插入,子串的插入等情况

  1. /*

  2. * Name:   Double Array Trie

  3. * Author: Yaguang Ding

  4. * Mail: dingyaguang117@gmail.com

  5. * Blog: blog.csdn.net/dingyaguang117

  6. * Date:   2012/5/21

  7. * Note: a word ends may be either of these two case:

  8. * 1. Base[cur_p] == pos  ( pos<0 and Tail[-pos] == 'END_CHAR' )

  9. * 2. Check[Base[cur_p] + Code('END_CHAR')] ==  cur_p

  10. */

  11. import java.util.ArrayList;  

  12. import java.util.HashMap;  

  13. import java.util.Map;  

  14. import java.util.Arrays;  

  15. public class DoubleArrayTrie {  

  16.    final char END_CHAR = '\0';  

  17.    final int DEFAULT_LEN = 1024;  

  18.    int Base[]  = new int [DEFAULT_LEN];  

  19.    int Check[] = new int [DEFAULT_LEN];  

  20.    char Tail[] = new char [DEFAULT_LEN];  

  21.    int Pos = 1;  

  22.    Map<Character ,Integer> CharMap = new HashMap<Character,Integer>();  

  23.    ArrayList<Character> CharList = new ArrayList<Character>();  

  24.      

  25.    public DoubleArrayTrie()  

  26.    {  

  27.        Base[1] = 1;  

  28.          

  29.        CharMap.put(END_CHAR,1);  

  30.        CharList.add(END_CHAR);  

  31.        CharList.add(END_CHAR);  

  32.        for(int i=0;i<26;++i)  

  33.        {  

  34.            CharMap.put((char)('a'+i),CharMap.size()+1);  

  35.            CharList.add((char)('a'+i));  

  36.        }  

  37.          

  38.    }  

  39.    private void Extend_Array()  

  40.    {  

  41.        Base = Arrays.copyOf(Base, Base.length*2);  

  42.        Check = Arrays.copyOf(Check, Check.length*2);  

  43.    }  

  44.      

  45.    private void Extend_Tail()  

  46.    {  

  47.        Tail = Arrays.copyOf(Tail, Tail.length*2);  

  48.    }  

  49.      

  50.    private int GetCharCode(char c)  

  51.    {  

  52.        if (!CharMap.containsKey(c))  

  53.        {  

  54.            CharMap.put(c,CharMap.size()+1);  

  55.            CharList.add(c);  

  56.        }  

  57.        return CharMap.get(c);  

  58.    }  

  59.    private int CopyToTailArray(String s,int p)  

  60.    {  

  61.        int _Pos = Pos;  

  62.        while(s.length()-p+1 > Tail.length-Pos)  

  63.        {  

  64.            Extend_Tail();  

  65.        }  

  66.        for(int i=p; i<s.length();++i)  

  67.        {  

  68.            Tail[_Pos] = s.charAt(i);  

  69.            _Pos++;  

  70.        }  

  71.        return _Pos;  

  72.    }  

  73.      

  74.    private int x_check(Integer []set)  

  75.    {  

  76.        for(int i=1; ; ++i)  

  77.        {  

  78.            boolean flag = true;  

  79.            for(int j=0;j<set.length;++j)  

  80.            {  

  81.                int cur_p = i+set[j];  

  82.                if(cur_p>= Base.length) Extend_Array();  

  83.                if(Base[cur_p]!= 0 || Check[cur_p]!= 0)  

  84.                {  

  85.                    flag = false;  

  86.                    break;  

  87.                }  

  88.            }  

  89.            if (flag) return i;  

  90.        }  

  91.    }  

  92.      

  93.    private ArrayList<Integer> GetChildList(int p)  

  94.    {  

  95.        ArrayList<Integer> ret = new ArrayList<Integer>();  

  96.        for(int i=1; i<=CharMap.size();++i)  

  97.        {  

  98.            if(Base[p]+i >= Check.length) break;  

  99.            if(Check[Base[p]+i] == p)  

  100.            {  

  101.                ret.add(i);  

  102.            }  

  103.        }  

  104.        return ret;  

  105.    }  

  106.      

  107.    private boolean TailContainString(int start,String s2)  

  108.    {  

  109.        for(int i=0;i<s2.length();++i)  

  110.        {  

  111.            if(s2.charAt(i) != Tail[i+start]) return false;  

  112.        }  

  113.          

  114.        return true;  

  115.    }  

  116.    private boolean TailMatchString(int start,String s2)  

  117.    {  

  118.        s2 += END_CHAR;  

  119.        for(int i=0;i<s2.length();++i)  

  120.        {  

  121.            if(s2.charAt(i) != Tail[i+start]) return false;  

  122.        }  

  123.        return true;  

  124.    }  

  125.      

  126.      

  127.    public void Insert(String s) throws Exception  

  128.    {  

  129.        s += END_CHAR;  

  130.          

  131.        int pre_p = 1;  

  132.        int cur_p;  

  133.        for(int i=0; i<s.length(); ++i)  

  134.        {  

  135.            //获取状态位置

  136.            cur_p = Base[pre_p]+GetCharCode(s.charAt(i));  

  137.            //如果长度超过现有,拓展数组

  138.            if (cur_p >= Base.length) Extend_Array();  

  139.              

  140.            //空闲状态

  141.            if(Base[cur_p] == 0 && Check[cur_p] == 0)  

  142.            {  

  143.                Base[cur_p] = -Pos;  

  144.                Check[cur_p] = pre_p;  

  145.                Pos = CopyToTailArray(s,i+1);  

  146.                break;  

  147.            }else

  148.            //已存在状态

  149.            if(Base[cur_p] > 0 && Check[cur_p] == pre_p)  

  150.            {  

  151.                pre_p = cur_p;  

  152.                continue;  

  153.            }else

  154.            //冲突 1:遇到 Base[cur_p]小于0的,即遇到一个被压缩存到Tail中的字符串

  155.            if(Base[cur_p] < 0 && Check[cur_p] == pre_p)  

  156.            {  

  157.                int head = -Base[cur_p];  

  158.                  

  159.                if(s.charAt(i+1)== END_CHAR && Tail[head]==END_CHAR)    //插入重复字符串

  160.                {  

  161.                    break;  

  162.                }  

  163.                  

  164.                //公共字母的情况,因为上一个判断已经排除了结束符,所以一定是2个都不是结束符

  165.                if (Tail[head] == s.charAt(i+1))  

  166.                {  

  167.                    int avail_base = x_check(new Integer[]{GetCharCode(s.charAt(i+1))});  

  168.                    Base[cur_p] = avail_base;  

  169.                      

  170.                    Check[avail_base+GetCharCode(s.charAt(i+1))] = cur_p;  

  171.                    Base[avail_base+GetCharCode(s.charAt(i+1))] = -(head+1);  

  172.                    pre_p = cur_p;  

  173.                    continue;  

  174.                }  

  175.                else

  176.                {  

  177.                    //2个字母不相同的情况,可能有一个为结束符

  178.                    int avail_base ;  

  179.                    avail_base = x_check(new Integer[]{GetCharCode(s.charAt(i+1)),GetCharCode(Tail[head])});  

  180.                      

  181.                    Base[cur_p] = avail_base;  

  182.                      

  183.                    Check[avail_base+GetCharCode(Tail[head])] = cur_p;  

  184.                    Check[avail_base+GetCharCode(s.charAt(i+1))] = cur_p;  

  185.                      

  186.                    //Tail 为END_FLAG 的情况

  187.                    if(Tail[head] == END_CHAR)  

  188.                        Base[avail_base+GetCharCode(Tail[head])] = 0;  

  189.                    else

  190.                        Base[avail_base+GetCharCode(Tail[head])] = -(head+1);  

  191.                    if(s.charAt(i+1) == END_CHAR)  

  192.                        Base[avail_base+GetCharCode(s.charAt(i+1))] = 0;  

  193.                    else

  194.                        Base[avail_base+GetCharCode(s.charAt(i+1))] = -Pos;  

  195.                      

  196.                    Pos = CopyToTailArray(s,i+2);  

  197.                    break;  

  198.                }  

  199.            }else

  200.            //冲突2:当前结点已经被占用,需要调整pre的base

  201.            if(Check[cur_p] != pre_p)  

  202.            {  

  203.                ArrayList<Integer> list1 = GetChildList(pre_p);  

  204.                int toBeAdjust;  

  205.                ArrayList<Integer> list = null;  

  206.                if(true)  

  207.                {  

  208.                    toBeAdjust = pre_p;  

  209.                    list = list1;  

  210.                }  

  211.                  

  212.                int origin_base = Base[toBeAdjust];  

  213.                list.add(GetCharCode(s.charAt(i)));  

  214.                int avail_base = x_check((Integer[])list.toArray(new Integer[list.size()]));  

  215.                list.remove(list.size()-1);  

  216.                  

  217.                Base[toBeAdjust] = avail_base;  

  218.                for(int j=0; j<list.size(); ++j)  

  219.                {  

  220.                    //BUG  

  221.                    int tmp1 = origin_base + list.get(j);  

  222.                    int tmp2 = avail_base + list.get(j);  

  223.                      

  224.                    Base[tmp2] = Base[tmp1];  

  225.                    Check[tmp2] = Check[tmp1];  

  226.                      

  227.                    //有后续

  228.                    if(Base[tmp1] > 0)  

  229.                    {  

  230.                        ArrayList<Integer> subsequence = GetChildList(tmp1);  

  231.                        for(int k=0; k<subsequence.size(); ++k)  

  232.                        {  

  233.                            Check[Base[tmp1]+subsequence.get(k)] = tmp2;  

  234.                        }  

  235.                    }  

  236.                      

  237.                    Base[tmp1] = 0;  

  238.                    Check[tmp1] = 0;  

  239.                }  

  240.                  

  241.                //更新新的cur_p

  242.                cur_p = Base[pre_p]+GetCharCode(s.charAt(i));  

  243.                  

  244.                if(s.charAt(i) == END_CHAR)  

  245.                    Base[cur_p] = 0;  

  246.                else

  247.                    Base[cur_p] = -Pos;  

  248.                Check[cur_p] = pre_p;  

  249.                Pos = CopyToTailArray(s,i+1);  

  250.                break;  

  251.            }  

  252.        }  

  253.    }  

  254.      

  255.    public boolean Exists(String word)  

  256.    {  

  257.        int pre_p = 1;  

  258.        int cur_p = 0;  

  259.          

  260.        for(int i=0;i<word.length();++i)  

  261.        {  

  262.            cur_p = Base[pre_p]+GetCharCode(word.charAt(i));  

  263.            if(Check[cur_p] != pre_p) return false;  

  264.            if(Base[cur_p] < 0)  

  265.            {  

  266.                if(TailMatchString(-Base[cur_p],word.substring(i+1)))  

  267.                    return true;  

  268.                return false;  

  269.            }  

  270.            pre_p = cur_p;  

  271.        }  

  272.        if(Check[Base[cur_p]+GetCharCode(END_CHAR)] == cur_p)  

  273.            return true;  

  274.        return false;  

  275.    }  

  276.      

  277.    //内部函数,返回匹配单词的最靠后的Base index,

  278.    class FindStruct  

  279.    {  

  280.        int p;  

  281.        String prefix="";  

  282.    }  

  283.    private FindStruct Find(String word)  

  284.    {  

  285.        int pre_p = 1;  

  286.        int cur_p = 0;  

  287.        FindStruct fs = new FindStruct();  

  288.        for(int i=0;i<word.length();++i)  

  289.        {  

  290.            // BUG

  291.            fs.prefix += word.charAt(i);  

  292.            cur_p = Base[pre_p]+GetCharCode(word.charAt(i));  

  293.            if(Check[cur_p] != pre_p)  

  294.            {  

  295.                fs.p = -1;  

  296.                return fs;  

  297.            }  

  298.            if(Base[cur_p] < 0)  

  299.            {  

  300.                if(TailContainString(-Base[cur_p],word.substring(i+1)))  

  301.                {  

  302.                    fs.p = cur_p;  

  303.                    return fs;  

  304.                }  

  305.                fs.p = -1;  

  306.                return fs;  

  307.            }  

  308.            pre_p = cur_p;  

  309.        }  

  310.        fs.p =  cur_p;  

  311.        return fs;  

  312.    }  

  313.      

  314.    public ArrayList<String> GetAllChildWord(int index)  

  315.    {  

  316.        ArrayList<String> result = new ArrayList<String>();  

  317.        if(Base[index] == 0)  

  318.        {  

  319.            result.add("");  

  320.            return result;  

  321.        }  

  322.        if(Base[index] < 0)  

  323.        {  

  324.            String r="";  

  325.            for(int i=-Base[index];Tail[i]!=END_CHAR;++i)  

  326.            {  

  327.                r+= Tail[i];  

  328.            }  

  329.            result.add(r);  

  330.            return result;  

  331.        }  

  332.        for(int i=1;i<=CharMap.size();++i)  

  333.        {  

  334.            if(Check[Base[index]+i] == index)  

  335.            {  

  336.                for(String s:GetAllChildWord(Base[index]+i))  

  337.                {  

  338.                    result.add(CharList.get(i)+s);  

  339.                }  

  340.                //result.addAll(GetAllChildWord(Base[index]+i));

  341.            }  

  342.        }  

  343.        return result;  

  344.    }  

  345.      

  346.    public ArrayList<String> FindAllWords(String word)  

  347.    {  

  348.        ArrayList<String> result = new ArrayList<String>();  

  349.        String prefix = "";  

  350.        FindStruct fs = Find(word);  

  351.        int p = fs.p;  

  352.        if (p == -1) return result;  

  353.        if(Base[p]<0)  

  354.        {  

  355.            String r="";  

  356.            for(int i=-Base[p];Tail[i]!=END_CHAR;++i)  

  357.            {  

  358.                r+= Tail[i];  

  359.            }  

  360.            result.add(fs.prefix+r);  

  361.            return result;  

  362.        }  

  363.          

  364.        if(Base[p] > 0)  

  365.        {  

  366.            ArrayList<String> r =  GetAllChildWord(p);  

  367.            for(int i=0;i<r.size();++i)  

  368.            {  

  369.                r.set(i, fs.prefix+r.get(i));  

  370.            }  

  371.            return r;  

  372.        }  

  373.          

  374.        return result;  

  375.    }  

  376.      

  377. }

测  试

  1. import java.io.BufferedReader;  

  2. import java.io.FileInputStream;  

  3. import java.io.IOException;  

  4. import java.io.InputStream;  

  5. import java.io.InputStreamReader;  

  6. import java.util.ArrayList;  

  7. import java.util.Scanner;  

  8. import javax.xml.crypto.Data;  

  9. public class Main {  

  10.    public static void main(String[] args) throws Exception {  

  11.        ArrayList<String> words = new ArrayList<String>();  

  12.        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream("E:/兔子的试验学习中心[课内]/ACM大赛/ACM第四届校赛/E命令提示/words3.dic")));  

  13.        String s;  

  14.        int num = 0;  

  15.        while((s=reader.readLine()) != null)  

  16.        {  

  17.            words.add(s);  

  18.            num ++;  

  19.        }  

  20.        DoubleArrayTrie dat = new DoubleArrayTrie();  

  21.          

  22.        for(String word: words)  

  23.        {  

  24.            dat.Insert(word);  

  25.        }  

  26.          

  27.        System.out.println(dat.Base.length);  

  28.        System.out.println(dat.Tail.length);  

  29.          

  30.        Scanner sc = new Scanner(System.in);  

  31.        while(sc.hasNext())  

  32.        {  

  33.            String word = sc.next();  

  34.            System.out.println(dat.Exists(word));  

  35.            System.out.println(dat.FindAllWords(word));  

  36.        }  

  37.          

  38.    }  

  39. }  

下面是测试结果,构造6W英文单词的DAT,大概需要20秒


我增长数组的时候是每次长度增加到2倍,初始1024

Base和Check数组的长度为131072

Tail的长度为262144


本文网址:http://www.mingyangnet.com/html/hangye/1227.html
读完这篇文章后,您心情如何?
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
  • 0
更多>>网友评论
发表评论
编辑推荐
  • 没有资料