数据结构:前缀树Trie

引子

在刷题的过程中,经常会遇到这样一种典型问题:

给一组字符串List<String> strs,找出其中前缀为String p的所有字符串。

朴素的做法就是遍历strs,然后每一个看一下是不是有前缀为p,这样的效率是O(N*L),其中N是strs中字符串的个数,L是p的长度。
这样看来,其实也不差。就单次操作而言,这些操作都是必须的。问题在于,当出现多次这种操作的时候,每次都需要重新判断前缀,这样效率就不高了。
不卖关子了,相信大家也知道什么数据结构好用了——就是今天要介绍的前缀树(或者叫字典树,又称单词查找树或键树),英文叫做Trie,专门用来处理这种操作。

介绍

借用百度百科的图:


Trie

最简单的Trie的话,只需要保持从当前节点到子节点的映射就行了,子节点的字符可以从当前节点得到:

import java.util.HashMap;
import java.util.Map;

class TrieNode {

    private Map<Character, TrieNode> children = new HashMap<Character, TrieNode>();

    public TrieNode() {
    }

    public Map<Character, TrieNode> getChildren() {
        return children;
    }

    public void setChildren(HashMap<Character, TrieNode> children) {
        this.children = children;
    }
}

在此基础上,我们可以根据需要来加入新的变量:

  • 可以用一个Boolean来表示当前节点是不是某个字符串的末尾,从而在路径重合的时候知道有哪些字符串;
  • 可以用一个Int来表示当前节点下面完整字符串的个数,用于快速计数某前缀的个数;
  • 可以用一个List来存储当前节点下面的完整字符串,用于快速返回含有某个前缀的所有字符串。

可以列举的还有很多,总之Trie的一大优势就是灵活,按照具体需要进行修改。

那么以最简单的Trie为基础,如何实现基本的增删查改操作?

  • 增的话,比较简单,基本上是“顺藤摸瓜”,有就一直跟着走,什么地方断了就自己新建TrieNode。时间复杂度O(n),n是增加的字符串的长度:
    public static void add(String word, TrieNode root) {
        TrieNode cur = root;
        for (char ch : word.toCharArray()) {
            if (cur.getChildren().containsKey(ch)) {
                cur = cur.getChildren().get(ch);
            } else {
                TrieNode node = new TrieNode();
                cur.getChildren().put(ch, node);
                cur = node;
            }
        }
    }
  • 查也不难,思路和增类似,只是不需要自己创建新节点而是直接返回false。时间复杂度还是O(n):
    public static boolean hasPrefix(String prefix, TrieNode root) {
        TrieNode cur = root;
        for (char ch : prefix.toCharArray()) {
            if (cur.getChildren().containsKey(ch)) {
                cur = cur.getChildren().get(ch);
            } else {
                return false;
            }
        }
        return true;
    }
  • 删的话相对复杂一些,一个例子就是两个完全一样的字符串,只删一个怎么实现。上面写的数据结构是不支持重复字符串的,因此还需要引入一个新变量来存储对应字符串的个数。这样就复杂许多了。
    因此,这里的实现假设字符串都是不重复的。由于搜索的顺序和删的顺序其实是相反的,因此递归是不错的选择。删无非2种情况,需要删除本节点,和保留本节点给其他字符串,所以递归还是需要一个返回值的。复杂度还是O(n)。
    public static void remove(String word, TrieNode root) {
        removeHelper(word, root, 0);
    }

    private static boolean removeHelper(String word, TrieNode node, int idx) {
        if (idx != word.length()) {
            char ch = word.charAt(idx);
            TrieNode next = node.getChildren().get(ch);
            if (removeHelper(word, next, idx + 1)) {
                node.getChildren().remove(ch);
            }
        }
        return node.getChildren().isEmpty();
    }
  • 改的话,就可以转化成增和删的组合。

回顾

回到刚开始尝试解决的问题,假如要找到所有前缀为p的字符串,那么就是遍历到前缀p的最后一个节点,之后所有标记为完整字符串的就都是了。极限情况下,前缀p为空,那么相当于要遍历整棵树,而整棵树最坏情况就是所有字符串都不重合,那么也就是说复杂度是O(N*S),N是字符串个数,S是字符串的平均长度。可见此时并不优越。
那么怎么办呢?Trie是可以根据需要灵活改变的。在原有Trie的基础上,执行要求的操作效率不高的原因在于,就算确定了前缀的节点,后面的字符串还是要一个个去遍历出来,从而效率不高。很直接的解决思想就是加一个List来维持节点所对应的字符串。
这样改变之后,操作的复杂度就是O(p),p为前缀p的长度,因为只需要找到节点读一下List即可。空间复杂度的话,其实虽然同样的字符串在很多节点出现,但其实都指向同一个对象,其实没有想象的那么糟糕。或者还可以存字符串List里原来的index。完整代码如下:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TrieNode {

    private Map<Character, TrieNode> children = new HashMap<Character, TrieNode>();
    private List<String> words = new ArrayList<>();

    public List<String> getWords() {
        return words;
    }

    public void setWords(List<String> words) {
        this.words = words;
    }

    public TrieNode() {
    }

    public Map<Character, TrieNode> getChildren() {
        return children;
    }

    public void setChildren(HashMap<Character, TrieNode> children) {
        this.children = children;
    }

    public static void add(String word, TrieNode root) {
        root.words.add(word);
        TrieNode cur = root;
        for (char ch : word.toCharArray()) {
            if (cur.getChildren().containsKey(ch)) {
                cur = cur.getChildren().get(ch);
            } else {
                TrieNode node = new TrieNode();
                cur.getChildren().put(ch, node);
                cur = node;
            }
            cur.words.add(word);
        }
    }

    public static List<String> getWordsWithPrefix(String prefix, TrieNode root) {
        TrieNode cur = root;
        for (char ch : prefix.toCharArray()) {
            if (cur.getChildren().containsKey(ch)) {
                cur = cur.getChildren().get(ch);
            } else {
                return new ArrayList<>();
            }
        }
        return cur.words;
    }
}

运用

下面看一个例题,力扣的Word Squares:



值得注意的一点是字符串都是可以重复使用的。

解法与思路

形成Word Squares要求第k行和第k列一致。
第一行的单词没有任何限制。假设是wall。
第二行,k=0,那么就是说第0列已经确定了,也必须是wall,那么第二行必须以a开始。假设是area。
第三行,同样k=0,k=1都限制了前缀必须是le。假设是lead。
第四行,限制了前缀必须是lad。
由此可见,每一行都会增加一个限制,前缀必须为第k列的开头。
因此,一个直接的思路是DFS加back tracking:

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class Solution {

    public List<List<String>> wordSquares(String[] words) {
        Set<List<String>> ret = new HashSet<>();
        if (words.length == 0) return new ArrayList<>(ret);
        recur(words, new ArrayList<>(), ret);
        return new ArrayList<>(ret);
    }

    private void recur(String[] words, List<String> temp, Set<List<String>> ret) {
        if (temp.size() == words[0].length()) {
            ret.add(new ArrayList<>(temp));
            return;
        }
        for (String w : words) {
            if (meetRequirements(w, temp)) {
                temp.add(w);
                recur(words, temp, ret);
                temp.remove(temp.size() - 1);
            }
        }
    }

    private boolean meetRequirements(String w, List<String> temp) {
        int cnt = temp.size();
        if (cnt == 0) return true;
        for (int i = 0; i < cnt; i++) {
            if (w.charAt(i) != temp.get(i).charAt(cnt)) return false;
        }
        return true;
    }
}

很可惜,这个解法TLE超时。为何?因为求下一个单词时,需要过一遍整个单词列表来判断哪些单词满足条件,效率不高。
这时候,我们之前学的Trie总算派上用场了:可以建一个Trie树,然后需要找前缀的时候在里面搜就好。完整代码如下(Leetcode AC):

import java.util.*;

class Solution {

    public List<List<String>> wordSquares(String[] words) {
        Set<List<String>> ret = new HashSet<>();
        if (words.length == 0) return new ArrayList<>(ret);
        TrieNode root = buildTrie(words);
        recur(words, new ArrayList<>(), ret, root);
        return new ArrayList<>(ret);
    }

    private TrieNode buildTrie(String[] words) {
        TrieNode root = new TrieNode();
        for (String s : words) {
            TrieNode.add(s, root);
        }
        return root;
    }

    private void recur(String[] words, List<String> temp, Set<List<String>> ret, TrieNode root) {
        if (temp.size() == words[0].length()) {
            ret.add(new ArrayList<>(temp));
            return;
        }
        // get prefix
        StringBuilder sb = new StringBuilder(temp.size());
        for (int i = 0; i < temp.size(); i++) {
            sb.append(temp.get(i).charAt(temp.size()));
        }
        for (String w : TrieNode.getWordsWithPrefix(sb.toString(), root)) {
            temp.add(w);
            recur(words, temp, ret, root);
            temp.remove(temp.size() - 1);
        }
    }

    private static class TrieNode {

        private Map<Character, TrieNode> children = new HashMap<Character, TrieNode>();
        private final List<String> words = new ArrayList<>();

        public TrieNode() {
        }

        public Map<Character, TrieNode> getChildren() {
            return children;
        }

        public void setChildren(HashMap<Character, TrieNode> children) {
            this.children = children;
        }

        public static void add(String word, TrieNode root) {
            root.words.add(word);
            TrieNode cur = root;
            for (char ch : word.toCharArray()) {
                if (cur.getChildren().containsKey(ch)) {
                    cur = cur.getChildren().get(ch);
                } else {
                    TrieNode node = new TrieNode();
                    cur.getChildren().put(ch, node);
                    cur = node;
                }
                cur.words.add(word);
            }
        }

        public static List<String> getWordsWithPrefix(String prefix, TrieNode root) {
            TrieNode cur = root;
            for (char ch : prefix.toCharArray()) {
                if (cur.getChildren().containsKey(ch)) {
                    cur = cur.getChildren().get(ch);
                } else {
                    return new ArrayList<>();
                }
            }
            return cur.words;
        }
    }
}