本文基于 ThreadPoolExecutor 实现一个带优先级的线程池.
最近做一个 PPT 转 PDF 的功能, 通过调用 Office 的“另存为”实现, 耗时较长(大约 2 秒转一个文件), 而且只能单线程执行. 项目要求批量转换完成后发送邮件; 而如果用户手动点击生成 PDF, 则应该尽快生成, 不能等批量任务全部转换完成后才让用户下载. 所以就实现了一个带优先级的线程池任务队列. 其实常规的实现方式是直接使用优先级队列(java.util.PriorityQueue / java.util.concurrent.PriorityBlockingQueue), 但这种方式没办法同步地获取结果, 编程上有点复杂; 而 java.util.concurrent.ThreadPoolExecutor 提供了 public <T> Future<T> submit(Callable<T> task), 调用 Future.get() 会阻塞当前线程直到结果返回, 从而实现同步调用.
public class PriorityThreadPoolExecutor extends ThreadPoolExecutor;
实现方法很简单: 继承 ThreadPoolExecutor, 并使用 PriorityBlockingQueue 作为工作队列. PriorityBlockingQueue 有个坑, 它的文档中写道:
Operations on this class make no guarantees about the ordering of elements with equal priority.
*如果优先级相同, 不能确定顺序.*
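实际测试下来的结果是: 如果优先级相同, 执行顺序与插入顺序相反. 这就尴尬了, 这还是 FIFO 队列吗? PriorityBlockingQueue 的官方文档给出了解决方式(文档中的 FIFOEntry 示例): 为每一个队列元素分配一个递增的序号, 优先级相同时按序号比较, 照抄就可以了. 限制是队列历史元素总数不能超过 Long.MAX_VALUE 个(序号是 long 类型, 实际使用中几乎不可能达到). 为此实现一个 Comparable 的类: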
class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>>;
重载线程池添加任务的方法, 追加一个优先级参数. 如果使用基类的方法, 优先级默认为 0.
public void execute(Runnable command, int priority);
public <T> Future<T> submit(Callable<T> task, int priority);
public <T> Future<T> submit(Runnable task, T result, int priority);
public Future<?> submit(Runnable task, int priority);
最终代码如下
package wang.lcs.sys.util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
/**
 * A {@link ThreadPoolExecutor} whose queued tasks run in priority order: a larger
 * {@code priority} value runs first, and tasks with equal priority run in submission
 * (FIFO) order.
 *
 * <p>The inherited {@code execute(Runnable)} and {@code submit(...)} methods use priority
 * {@code 0}; the overloads taking an extra {@code int priority} argument let the caller choose.
 *
 * <p>Caveats:
 * <ul>
 *   <li>The work queue is an unbounded {@link PriorityBlockingQueue}. Per the
 *       {@code ThreadPoolExecutor} documentation, with an unbounded queue no more than
 *       {@code corePoolSize} threads are created, so {@code maximumPoolSize} has no effect.</li>
 *   <li>Only tasks that wait in the queue are ordered. A task handed directly to a newly
 *       started worker thread starts immediately, whatever its priority.</li>
 * </ul>
 */
public class PriorityThreadPoolExecutor extends ThreadPoolExecutor {

    /**
     * Carries the priority from the {@code submit(..., int priority)} overloads to
     * {@link #execute(Runnable)} on the submitting thread.
     *
     * <p>The inherited {@code submit} wraps the task and then calls {@code execute}, which reads
     * the value here. It is always cleared after use so it cannot leak into a later submission.
     */
    private final ThreadLocal<Integer> local = ThreadLocal.withInitial(() -> 0);

    /** Creates a pool with the default thread factory and rejection handler. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue());
    }

    /** Creates a pool with the given thread factory and the default rejection handler. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
                                      ThreadFactory threadFactory) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory);
    }

    /** Creates a pool with the given rejection handler and the default thread factory. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
                                      RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), handler);
    }

    /** Creates a pool with the given thread factory and rejection handler. */
    public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit,
                                      ThreadFactory threadFactory, RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory, handler);
    }

    /** Creates the unbounded priority queue that backs each executor instance. */
    protected static PriorityBlockingQueue<Runnable> getWorkQueue() {
        return new PriorityBlockingQueue<>();
    }

    /**
     * Executes {@code command} with the priority staged by a {@code submit(..., priority)} call
     * on this thread, or {@code 0} when none was staged.
     */
    @Override
    public void execute(Runnable command) {
        int priority = local.get();
        // Consume the staged value before delegating, so that it is used exactly once even if
        // the call below throws (null command, rejection) or re-enters this executor.
        local.remove();
        execute(command, priority);
    }

    /**
     * Executes {@code command} with the given priority (larger runs first).
     *
     * @throws NullPointerException if {@code command} is null
     */
    public void execute(Runnable command, int priority) {
        if (command == null) {
            // Fail at the call site instead of queueing a wrapper that NPEs on a worker thread.
            throw new NullPointerException("command");
        }
        super.execute(new PriorityRunnable<>(command, priority));
    }

    /** Submits a value-returning task with the given priority (larger runs first). */
    public <T> Future<T> submit(Callable<T> task, int priority) {
        local.set(priority);
        try {
            return super.submit(task);
        } finally {
            // super.submit throws before reaching execute() for a null task; never leave the
            // priority behind for the next submission from this thread.
            local.remove();
        }
    }

    /** Submits a Runnable with the given priority; the future yields {@code result} on success. */
    public <T> Future<T> submit(Runnable task, T result, int priority) {
        local.set(priority);
        try {
            return super.submit(task, result);
        } finally {
            local.remove();
        }
    }

    /** Submits a Runnable with the given priority; the future yields {@code null} on success. */
    public Future<?> submit(Runnable task, int priority) {
        local.set(priority);
        try {
            return super.submit(task);
        } finally {
            local.remove();
        }
    }

    /**
     * Queue element wrapping a task with its priority and a global submission sequence number.
     *
     * <p>Ordering: higher {@code priority} first; ties broken by ascending sequence number (FIFO),
     * following the {@code FIFOEntry} pattern from the {@code PriorityBlockingQueue} documentation.
     * Sequence numbers are unique per instance, so the ordering is total and only an instance
     * compares equal to itself.
     *
     * <p>The type parameter {@code E} is not used by this class.
     */
    protected static class PriorityRunnable<E extends Comparable<? super E>>
            implements Runnable, Comparable<PriorityRunnable<E>> {
        private static final AtomicLong seq = new AtomicLong();
        private final long seqNum;
        Runnable run;
        private int priority;

        public PriorityRunnable(Runnable run, int priority) {
            seqNum = seq.getAndIncrement();
            this.run = run;
            this.priority = priority;
        }

        public int getPriority() {
            return priority;
        }

        /**
         * Changes the priority. NOTE: the queue is not re-sorted when this changes; mutating a
         * task that is already queued can leave the queue mis-ordered.
         */
        public void setPriority(int priority) {
            this.priority = priority;
        }

        /** Returns the wrapped task. */
        public Runnable getRun() {
            return run;
        }

        @Override
        public void run() {
            this.run.run();
        }

        @Override
        public int compareTo(PriorityRunnable<E> other) {
            // Higher priority sorts first (descending).
            int byPriority = Integer.compare(other.priority, this.priority);
            if (byPriority != 0) {
                return byPriority;
            }
            // Equal priority: earlier submission first (ascending sequence number).
            return Long.compare(this.seqNum, other.seqNum);
        }
    }
}
下面是测试用例
package wang.lcs.sys.util;
import org.junit.Assert;
import org.junit.Test;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.*;
/**
 * JUnit tests for {@link PriorityThreadPoolExecutor}.
 *
 * <p>Every pool has a single core thread. Each task sleeps 10 ms before recording itself in a
 * shared buffer, so the first task typically occupies the worker while the rest are queued.
 * The order of entries in the buffer therefore reflects the queue order.
 */
public class PriorityThreadPoolExecutorTest {

    @Test
    public void testDefault() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        StringBuffer buffer = new StringBuffer();
        try {
            Future<?>[] futures = new Future<?>[20];
            for (int i = 0; i < futures.length; i++) {
                final int index = i;
                futures[i] = pool.submit(new Callable<Object>() {
                    @Override
                    public Object call() throws Exception {
                        Thread.sleep(10);
                        buffer.append(index + ", ");
                        return null;
                    }
                });
            }
            // Wait for all tasks to finish.
            awaitAll(futures);
        } finally {
            pool.shutdown();
        }
        System.out.println(buffer);
        assertEquals("0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ", buffer.toString());
    }

    @Test
    public void testSamePriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        StringBuffer buffer = new StringBuffer();
        try {
            Future<?>[] futures = new Future<?>[10];
            for (int i = 0; i < futures.length; i++) {
                futures[i] = pool.submit(new TenSecondTask<Void>(i, 1, buffer), 1);
            }
            // Wait for all tasks to finish.
            awaitAll(futures);
        } finally {
            pool.shutdown();
        }
        System.out.println(buffer);
        assertEquals("01@00, 01@01, 01@02, 01@03, 01@04, 01@05, 01@06, 01@07, 01@08, 01@09, ", buffer.toString());
    }

    @Test
    public void testRandomPriority() throws InterruptedException, ExecutionException {
        PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES);
        StringBuffer buffer = new StringBuffer();
        Future<?>[] futures = new Future<?>[20];
        try {
            for (int i = 0; i < futures.length; i++) {
                int r = (int) (Math.random() * 100);
                futures[i] = pool.submit(new TenSecondTask<Void>(i, r, buffer), r);
            }
            // Wait for all tasks to finish.
            awaitAll(futures);
        } finally {
            pool.shutdown();
        }
        System.out.println(buffer);
        // String.split drops the trailing empty string left by the final ", ".
        String[] entries = buffer.toString().split(", ");
        assertEquals(futures.length, entries.length);
        // Start at index 2: the first tasks may have been taken by the worker before every task
        // was queued, so they are not ordered by priority.
        for (int i = 2; i < entries.length - 1; i++) {
            assertTrue(entries[i] + " ran before " + entries[i + 1],
                    priorityOf(entries[i]) >= priorityOf(entries[i + 1]));
        }
    }

    /**
     * Blocks until every future has completed. Propagates the {@link ExecutionException} of the
     * first failed task, in array order.
     */
    private static void awaitAll(Future<?>[] futures) throws InterruptedException, ExecutionException {
        for (Future<?> future : futures) {
            future.get();
        }
    }

    /** Parses the priority prefix of an entry written by {@link TenSecondTask} ("PP@II"). */
    private static int priorityOf(String entry) {
        return Integer.parseInt(entry.substring(0, entry.indexOf('@')));
    }

    /**
     * Callable that sleeps 10 ms (despite the class name) and then appends
     * {@code "<priority>@<index>, "} to the shared buffer. Both numbers are zero-padded to two
     * digits. Always returns {@code null}.
     */
    public static class TenSecondTask<T> implements Callable<T> {
        private final StringBuffer buffer;
        int index;
        int priority;

        public TenSecondTask(int index, int priority, StringBuffer buffer) {
            this.index = index;
            this.priority = priority;
            this.buffer = buffer;
        }

        @Override
        public T call() throws Exception {
            Thread.sleep(10);
            buffer.append(String.format("%02d@%02d", this.priority, index)).append(", ");
            return null;
        }
    }
}
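需要说明的是: 这里使用了 ThreadLocal, 以减少代码的复制粘贴. 带优先级的 submit 方法先把优先级存入 ThreadLocal; 基类的 submit 包装好任务后会调用 execute(Runnable); 被重写的 execute 再从 ThreadLocal 中取出优先级并重置. 这样就不必重新实现基类 submit 的逻辑.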