基于 ThreadPoolExecutor 的优先级线程池

本文基于 ThreadPoolExecutor 实现一个带优先级的线程池。


最近在做一个 PPT 转 PDF 的功能：通过调用 Office 的“另存为”完成转换，耗时较长（大约 2 秒转换一个文件），而且只能单线程执行。项目要求批量转换并发送邮件；同时，如果用户手动点击“生成 PDF”，则应尽快生成，不能等批量任务全部转换完成后再让用户下载。所以就实现了一个带优先级的线程池任务队列。常规做法是直接使用优先级队列（java.util.PriorityQueue / java.util.concurrent.PriorityBlockingQueue），但这种方式没办法同步地获取结果，编程上有点复杂；而 java.util.concurrent.ThreadPoolExecutor 提供了 public <T> Future<T> submit(Callable<T> task)，调用 Future.get() 会阻塞当前线程直到结果返回，从而实现同步调用。

public class PriorityThreadPoolExecutor extends ThreadPoolExecutor;

实现方法很简单：继承 ThreadPoolExecutor，并使用 PriorityBlockingQueue 作为任务队列。不过 PriorityBlockingQueue 有个坑，其文档中写道：

Operations on this class make no guarantees about the ordering of elements with equal priority.

*即：优先级相同的元素，其先后顺序无法保证。*

实际测试下来，优先级相同时执行顺序与插入顺序相反，这就尴尬了——这还算是 FIFO 队列吗？官方文档给出了解决方法：为每个队列元素分配一个递增的序号，优先级相同时按序号先后排序，照抄即可。唯一的限制是历史提交的任务总数不能超过 long 的取值范围（Long.MAX_VALUE）。为此实现一个 Comparable 的包装类：

class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>>;

重载线程池提交任务的方法，追加一个 priority 参数；如果调用基类原有的方法，优先级默认为 0。

    public void execute(Runnable command, int priority);
    public <T> Future<T> submit(Callable<T> task, int priority);
    public <T> Future<T> submit(Runnable task, T result, int priority);
    public Future<?> submit(Runnable task, int priority);

最终代码如下

package wang.lcs.sys.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A {@link ThreadPoolExecutor} backed by a {@link PriorityBlockingQueue}. Queued tasks are started
 * in descending priority order. Tasks with equal priority are started in submission (FIFO) order.
 *
 * <p>Tasks submitted through the inherited {@code execute}/{@code submit} methods get priority
 * {@code 0}. The {@code execute(..., int)} and {@code submit(..., int)} overloads take an explicit
 * priority; a larger value runs earlier.
 *
 * <p>Caveats:
 * <ul>
 *   <li>The work queue is unbounded. Per the {@code ThreadPoolExecutor} contract, at most
 *       {@code corePoolSize} threads are created and {@code maximumPoolSize} has no effect.</li>
 *   <li>Priority only orders tasks that actually wait in the queue. While fewer than
 *       {@code corePoolSize} threads are running, a new task is handed straight to a new thread.</li>
 *   <li>Queue elements are {@link PriorityRunnable} wrappers. {@link #getQueue()},
 *       {@link #shutdownNow()}, {@link #remove(Runnable)} and {@link #purge()} therefore operate on
 *       the wrappers rather than on the originally submitted tasks.</li>
 * </ul>
 */
public class PriorityThreadPoolExecutor extends ThreadPoolExecutor {

    /** Priority given to tasks submitted without an explicit priority. */
    private static final int DEFAULT_PRIORITY = 0;

    /**
     * Hands the priority from a {@code submit(..., priority)} overload to {@link #execute(Runnable)}.
     * The inherited {@code submit} implementation calls that method with the {@code Future} it
     * created. This lets the priority overloads reuse the base-class {@code submit} logic.
     */
    private final ThreadLocal<Integer> local = ThreadLocal.withInitial(() -> DEFAULT_PRIORITY);

    /** Creates a priority pool; see the matching {@link ThreadPoolExecutor} constructor. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue());
    }

    /** Creates a priority pool; see the matching {@link ThreadPoolExecutor} constructor. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory);
    }

    /** Creates a priority pool; see the matching {@link ThreadPoolExecutor} constructor. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), handler);
    }

    /** Creates a priority pool; see the matching {@link ThreadPoolExecutor} constructor. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory, handler);
    }

    /**
     * Creates the unbounded priority queue used by every constructor. Its elements are the
     * {@link PriorityRunnable} wrappers created by {@link #execute(Runnable, int)}.
     */
    protected static PriorityBlockingQueue<Runnable> getWorkQueue() {
        return new PriorityBlockingQueue<>();
    }

    /**
     * Executes {@code command}. The priority is the one stashed on this thread by a
     * {@code submit(..., priority)} overload, or {@code 0} when this method is called directly.
     *
     * @throws NullPointerException if {@code command} is null
     */
    @Override
    public void execute(Runnable command) {
        int priority = local.get();
        // Clear before delegating so the value can never leak into a later call on this thread.
        local.remove();
        execute(command, priority);
    }

    /**
     * Executes {@code command} with the given priority. Higher values are taken from the queue first.
     *
     * @param command  the task to run
     * @param priority the task priority; larger runs earlier
     * @throws NullPointerException if {@code command} is null
     */
    @SuppressWarnings("rawtypes") // PriorityRunnable's type parameter is unused (kept for compatibility).
    public void execute(Runnable command, int priority) {
        if (command == null) {
            // Fail fast, as ThreadPoolExecutor#execute does, instead of failing later on a worker.
            throw new NullPointerException("command");
        }
        super.execute(new PriorityRunnable(command, priority));
    }

    /**
     * Submits a value-returning task with the given priority.
     *
     * @return a {@code Future} for the task's result
     * @throws NullPointerException if {@code task} is null
     */
    public <T> Future<T> submit(Callable<T> task, int priority) {
        local.set(priority);
        try {
            return super.submit(task);
        } finally {
            // execute(Runnable) normally consumes the value. Clear it anyway in case submit threw
            // before calling execute, e.g. for a null task.
            local.remove();
        }
    }

    /**
     * Submits a {@code Runnable} with the given priority; the returned future yields {@code result}.
     *
     * @throws NullPointerException if {@code task} is null
     */
    public <T> Future<T> submit(Runnable task, T result, int priority) {
        local.set(priority);
        try {
            return super.submit(task, result);
        } finally {
            local.remove();
        }
    }

    /**
     * Submits a {@code Runnable} with the given priority; the returned future yields {@code null}.
     *
     * @throws NullPointerException if {@code task} is null
     */
    public Future<?> submit(Runnable task, int priority) {
        local.set(priority);
        try {
            return super.submit(task);
        } finally {
            local.remove();
        }
    }

    /**
     * Queue element that wraps a task together with its priority and a global submission sequence
     * number. Higher {@code priority} sorts first. Equal priorities sort by ascending sequence number
     * (FIFO), which is the tie-breaking approach suggested by the {@link PriorityBlockingQueue} docs.
     *
     * <p>The type parameter {@code E} is not used by the implementation. It is kept only so existing
     * references to {@code PriorityRunnable<...>} keep compiling.
     *
     * @param <E> unused
     */
    protected static class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>> {
        /** Source of unique, increasing sequence numbers across all instances. */
        private static final AtomicLong seq = new AtomicLong();
        private final long seqNum;
        Runnable run;
        private int priority;

        public PriorityRunnable(Runnable run, int priority) {
            seqNum = seq.getAndIncrement();
            this.run = run;
            this.priority = priority;
        }

        public int getPriority() {
            return priority;
        }

        /**
         * Changes the priority. Do not call this while the element is inside a priority queue: the
         * queue does not re-sort on mutation, so its ordering would become inconsistent.
         */
        public void setPriority(int priority) {
            this.priority = priority;
        }

        public Runnable getRun() {
            return run;
        }

        @Override
        public void run() {
            this.run.run();
        }

        @Override
        public int compareTo(PriorityRunnable<E> other) {
            // DESC by priority: a larger priority must compare as "smaller" to be polled first.
            int byPriority = Integer.compare(other.priority, this.priority);
            if (byPriority != 0) {
                return byPriority;
            }
            // ASC by sequence number: FIFO among equal priorities. Sequence numbers are unique per
            // instance, so this is 0 only when an element is compared with itself, even if two
            // wrappers share the same underlying Runnable.
            return Long.compare(this.seqNum, other.seqNum);
        }
    }
}

下面是测试用例

package wang.lcs.sys.util;

import org.junit.Assert;
import org.junit.Test;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static org.junit.Assert.*;

public class PriorityThreadPoolExecutorTest {

    /**
     * Creates a pool with a single worker thread, so queued tasks run strictly one after another.
     * The work queue is unbounded, so maximumPoolSize has no effect here.
     */
    private static PriorityThreadPoolExecutor newSingleWorkerPool() {
        return new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
    }

    /** Blocks until every future has completed, rethrowing the first task failure. */
    private static void awaitAll(Future<?>[] futures) throws InterruptedException, ExecutionException {
        for (Future<?> future : futures) {
            future.get();
        }
    }

    /** Parses the priority prefix of a {@code "pp@ii"} entry written by {@link TenSecondTask}. */
    private static int priorityOf(String entry) {
        return Integer.parseInt(entry.split("@")[0]);
    }

    /** Tasks submitted without a priority all get priority 0 and must run in submission order. */
    @Test
    public void testDefault() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = newSingleWorkerPool();
        try {
            Future<?>[] futures = new Future<?>[20];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                int index = i;
                Callable<Object> task = () -> {
                    Thread.sleep(10);
                    buffer.append(index + ", ");
                    return null;
                };
                futures[i] = pool.submit(task);
            }
            awaitAll(futures);
            System.out.println(buffer);
            assertEquals("0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ", buffer.toString());
        } finally {
            // Stop the core thread; otherwise every test leaks a live pool.
            pool.shutdownNow();
        }
    }

    /** Tasks that share an explicit priority must still run in submission (FIFO) order. */
    @Test
    public void testSamePriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = newSingleWorkerPool();
        try {
            Future<?>[] futures = new Future<?>[10];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                futures[i] = pool.submit(new TenSecondTask<Void>(i, 1, buffer), 1);
            }
            awaitAll(futures);
            System.out.println(buffer);
            assertEquals("01@00, 01@01, 01@02, 01@03, 01@04, 01@05, 01@06, 01@07, 01@08, 01@09, ", buffer.toString());
        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * Tasks with random priorities must be dequeued in non-increasing priority order. The first
     * entries are skipped. The first task is handed straight to the new worker thread, and the next
     * one may be dequeued before every task has been submitted.
     */
    @Test
    public void testRandomPriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = newSingleWorkerPool();
        try {
            Future<?>[] futures = new Future<?>[20];
            StringBuffer buffer = new StringBuffer();
            for (int i = 0; i < futures.length; i++) {
                int priority = (int) (Math.random() * 100);
                futures[i] = pool.submit(new TenSecondTask<Void>(i, priority, buffer), priority);
            }
            awaitAll(futures);
            System.out.println(buffer);

            // split() drops the trailing empty string after the final ", ", so there is exactly one
            // entry per task. No sentinel is needed. A sentinel with priority 1 would make the check
            // fail whenever a task drew priority 0.
            String[] entries = buffer.toString().split(", ");
            assertEquals(futures.length, entries.length);
            for (int i = 2; i < entries.length - 1; i++) {
                assertTrue(priorityOf(entries[i]) >= priorityOf(entries[i + 1]));
            }
        } finally {
            pool.shutdownNow();
        }
    }

    /**
     * Test task that sleeps 10 ms (despite its name). It then appends {@code "pp@ii, "} (its
     * two-digit priority and index) to the shared buffer and returns {@code null}.
     */
    public static class TenSecondTask<T> implements Callable<T> {
        private final StringBuffer buffer;
        int index;
        int priority;

        public TenSecondTask(int index, int priority, StringBuffer buffer) {
            this.index = index;
            this.priority = priority;
            this.buffer = buffer;
        }

        @Override
        public T call() throws Exception {
            Thread.sleep(10);
            buffer.append(String.format("%02d@%02d", this.priority, index)).append(", ");
            return null;
        }
    }
}

需要说明的是：这里使用 ThreadLocal 在 submit(..., priority) 与 execute(Runnable) 之间传递优先级，从而直接复用基类的 submit 实现，减少代码的复制粘贴。

推荐阅读更多精彩内容