简介
为什么需要多线程
项目里经常会遇到这样的场景:
- 读一次数据库的某个表
- 遍历这些数据,对每一个数据,都以它为条件再次查其他表
- 将第二步查到的数据组合起来,返回给前端
那么问题来了,数据量大了之后,会很费时间,而且在循环里边操作数据库是非常低效的。这点也在《阿里巴巴Java开发规范手册》里边:“循环体中的语句要考量性能,以下操作尽量移至循环体外处理,如定义对象、变量、 获取数据库连接,进行不必要的 try-catch 操作(这个 try-catch 是否可以移至循环体外)。”。
如果业务需求就只能在循环里操作数据库,那么最好的解决方法就是使用多线程处理,本文介绍如何处理。
实际场景
需求
假设有一个接口,获得所有订单的信息,按创建时间由近到远排序。
数据库表结构
- 订单表的字段:订单id、用户id、总金额、商品id、商品的数量、创建时间 等。
- 用户表的字段:用户id、用户名字、用户的电话、注册时间 等。
- 商品表的字段:商品id、商品名字 等。
业务流程
- 读出订单表前n行数据
- 用第1步的每个数据的用户id去用户表获取用户的信息
- 将第2步获取到的用户信息整合到订单VO里。
说明
本处为了简单,用这种方法处理:
- 仅获取订单的id、用户id、用户名字。
- 使用模拟数据,不实际操作数据库。用sleep()模拟对数据库的操作。
- 所有代码都在controller层处理。(实际业务中肯定要放到service中处理)。
思路
整体方案
- 多线程方案:使用线程池。
- 等待多线程执行完毕的方案:CoutDownLatch。
- 如何排序:
- 方案1:所有线程处理完之后统一排序。(此法简单,本文使用此方法)
- 方案2:使用队列,前边的线程将结果放入最终结果集后,唤醒下一个线程将结果放入结果集。(较为复杂)。
(搜索:“顺序消费实例”)
细节
- 线程池的核心线程数、最大线程数如何设置。
- 见: 搜索:“线程池个数设置”
- 线程安全的List。
- 见
代码
公共代码
controller
package com.example.order.controller;
import com.example.order.entity.OrderVO;
import com.example.order.entity.User;
import com.example.order.task.OrderTask;
import com.example.order.task.OrderTask2;
import com.example.utils.SynchroniseUtil;
import com.example.utils.ThreadPoolExecutors;
import com.example.utils.ThreadPoolExecutors2;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.PostConstruct;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class OrderController {
private List<OrderVO> orderVOS = new ArrayList<>();
private List<User> users = new ArrayList<>();
//初始化时就创建好数据。模拟数据库已经存在的数据
public void createData() {
long dataCount = 500;
// 创建订单数据。模拟已经插入到数据库的订单
for (long i = 0; i < dataCount; i++) {
OrderVO orderVO = new OrderVO();
orderVO.setId(i + 1);
orderVO.setUserId(i + 1);
//防止电脑太快,导致都是同一个时间,所以加一个数
orderVO.setCreateTime(LocalDateTime.now().plusSeconds(i));
orderVOS.add(orderVO);
}
// 创建用户数据。模拟已经插入到数据库的用户
for (long i = 0; i < dataCount; i++) {
User user = new User();
user.setId(i + 1);
user.setUserName("用户名" + (i + 1));
users.add(user);
}
orderVOS = orderVOS.stream()
.sorted(Comparator.comparing(OrderVO::getCreateTime).reversed())
.collect(Collectors.toList());
}
("/getOrderDetails")
public List<OrderVO> getOrderDetails() {
long startTime = System.currentTimeMillis();
List<OrderVO> orderVOList;
//这里是不同的执行方式(单线程/线程池)
long endTime = System.currentTimeMillis();
System.out.println("执行时间:" + (endTime - startTime) + " ms");
return orderVOList;
}
}
entity
订单
package com.example.order.entity;
import lombok.Data;
import java.time.LocalDateTime;
public class Order {
private Long id;
private Long userId;
private LocalDateTime createTime;
}
订单视图(用于返回数据)
package com.example.order.entity;
import lombok.Data;
import lombok.EqualsAndHashCode;
(callSuper = true)
public class OrderVO extends Order{
private String userName;
}
用户
package com.example.order.entity;
import lombok.Data;
public class User {
private Long id;
private String userName;
}
方案1:单线程
代码
package com.example.order.controller;
import com.example.order.entity.OrderVO;
import com.example.order.entity.User;
import com.example.order.task.OrderTask;
import com.example.order.task.OrderTask2;
import com.example.utils.SynchroniseUtil;
import com.example.utils.ThreadPoolExecutors;
import com.example.utils.ThreadPoolExecutors2;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.PostConstruct;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class OrderController {
private List<OrderVO> orderVOS = new ArrayList<>();
private List<User> users = new ArrayList<>();
//初始化时就创建好数据。模拟数据库已经存在的数据
public void createData() {
long dataCount = 500;
// 创建订单数据。模拟已经插入到数据库的订单
for (long i = 0; i < dataCount; i++) {
OrderVO orderVO = new OrderVO();
orderVO.setId(i + 1);
orderVO.setUserId(i + 1);
//防止电脑太快,导致都是同一个时间,所以加一个数
orderVO.setCreateTime(LocalDateTime.now().plusSeconds(i));
orderVOS.add(orderVO);
}
// 创建用户数据。模拟已经插入到数据库的用户
for (long i = 0; i < dataCount; i++) {
User user = new User();
user.setId(i + 1);
user.setUserName("用户名" + (i + 1));
users.add(user);
}
orderVOS = orderVOS.stream()
.sorted(Comparator.comparing(OrderVO::getCreateTime).reversed())
.collect(Collectors.toList());
}
("/getOrderDetails")
public List<OrderVO> getOrderDetails() {
long startTime = System.currentTimeMillis();
List<OrderVO> orderVOList;
orderVOList = singleThread(orderVOS);
long endTime = System.currentTimeMillis();
System.out.println("执行时间:" + (endTime - startTime) + " ms");
return orderVOList;
}
private List<OrderVO> singleThread(List<OrderVO> orders) {
List<OrderVO> result = new ArrayList<>(orders);
for (OrderVO orderVO : result) {
//模拟从数据库里查数据
try {
Thread.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
for (User user : users) {
if (orderVO.getUserId().equals(user.getId())) {
orderVO.setUserName(user.getUserName());
break;
}
}
}
return result;
}
}
测试
请求:http://localhost:8080/getOrderDetails
后端打印
执行时间:7525 ms
前端结果:
总结
缺点
- 才500个数据,就用了7秒多,实在太慢。
方案2:线程池(每个数据一个任务)
简介
上边已经看到了,单线程特别慢,本处使用线程池来优化:每个数据一个任务。
controller
package com.example.order.controller;
import com.example.order.entity.OrderVO;
import com.example.order.entity.User;
import com.example.order.task.OrderTask;
import com.example.order.task.OrderTask2;
import com.example.utils.SynchroniseUtil;
import com.example.utils.ThreadPoolExecutors;
import com.example.utils.ThreadPoolExecutors2;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.PostConstruct;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class OrderController {
private List<OrderVO> orderVOS = new ArrayList<>();
private List<User> users = new ArrayList<>();
//初始化时就创建好数据。模拟数据库已经存在的数据
public void createData() {
long dataCount = 500;
// 创建订单数据。模拟已经插入到数据库的订单
for (long i = 0; i < dataCount; i++) {
OrderVO orderVO = new OrderVO();
orderVO.setId(i + 1);
orderVO.setUserId(i + 1);
//防止电脑太快,导致都是同一个时间,所以加一个数
orderVO.setCreateTime(LocalDateTime.now().plusSeconds(i));
orderVOS.add(orderVO);
}
// 创建用户数据。模拟已经插入到数据库的用户
for (long i = 0; i < dataCount; i++) {
User user = new User();
user.setId(i + 1);
user.setUserName("用户名" + (i + 1));
users.add(user);
}
orderVOS = orderVOS.stream()
.sorted(Comparator.comparing(OrderVO::getCreateTime).reversed())
.collect(Collectors.toList());
}
("/getOrderDetails")
public List<OrderVO> getOrderDetails() throws Exception{
long startTime = System.currentTimeMillis();
List<OrderVO> orderVOList;
orderVOList = multiThread(orderVOS);
long endTime = System.currentTimeMillis();
System.out.println("执行时间:" + (endTime - startTime) + " ms");
return orderVOList;
}
private List<OrderVO> multiThread(List<OrderVO> orders) throws Exception{
ExecutorService executor = ThreadPoolExecutors.getSingletonExecutor();
SynchroniseUtil<OrderVO> synchroniseUtil = new SynchroniseUtil<>(orders.size());
System.out.println("任务个数:" + orders.size());
for (OrderVO order : orders) {
OrderTask orderTask = new OrderTask(order, users, synchroniseUtil);
executor.execute(orderTask);
}
List<OrderVO> list = null;
try {
list = synchroniseUtil.get(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
}
if (list != null) {
list = list.stream()
.sorted(Comparator.comparing(OrderVO::getCreateTime).reversed())
.collect(Collectors.toList());
}
return list;
}
}
自定义Task
package com.example.order.task;
import com.example.order.entity.OrderVO;
import com.example.order.entity.User;
import com.example.utils.SynchroniseUtil;
import java.util.List;
public class OrderTask implements Runnable {
private OrderVO orderVO;
private List<User> users;
private SynchroniseUtil<OrderVO> synchroniseUtil;
public OrderTask(OrderVO orderVO,
List<User> users,
SynchroniseUtil<OrderVO> synchroniseUtil) {
this.orderVO = orderVO;
this.users = users;
this.synchroniseUtil = synchroniseUtil;
}
public void run() {
//模拟从数据库里查数据
try {
Thread.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
for (User user : users) {
if (orderVO.getUserId().equals(user.getId())) {
orderVO.setUserName(user.getUserName());
break;
}
}
synchroniseUtil.addResult(orderVO);
}
}
单例模式的线程池
package com.example.utils;
import java.util.concurrent.*;
public class ThreadPoolExecutors {
private static final int processorNumber =
Runtime.getRuntime().availableProcessors();
private static class ThreadPoolExecutorsHolder {
private static final ExecutorService EXECUTOR =
Executors.newFixedThreadPool(processorNumber);
}
private ThreadPoolExecutors() {
}
public static ExecutorService getSingletonExecutor() {
System.out.println("处理器数量:" + processorNumber);
return ThreadPoolExecutorsHolder.EXECUTOR;
}
}
封装CoutDownLatch
package com.example.utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
public class SynchroniseUtil<T>{
private CountDownLatch countDownLatch;
private final List<T> result = Collections.synchronizedList(new ArrayList<>());
public SynchroniseUtil(int count) {
this.countDownLatch = new CountDownLatch(count);
}
public List<T> get() throws InterruptedException{
countDownLatch.await();
return this.result;
}
public List<T> get(long timeout, TimeUnit timeUnit) throws Exception{
if (countDownLatch.await(timeout, timeUnit)) {
return this.result;
} else {
throw new RuntimeException("超时");
}
}
public void addResult(T resultMember) {
result.add(resultMember);
countDownLatch.countDown();
}
public void addResult(List<T> resultMembers) {
result.addAll(resultMembers);
countDownLatch.countDown();
}
}
测试
访问:http://localhost:8080/getOrderDetails
后端结果
处理器数量:25
任务个数:500
执行时间:301 ms
前端结果
总结
优点
- 比单线程快很多。
缺点
- 固定线程池大小的线程池,队列长度是整型数的最大值,若数据很多,每个数据一个任务,会把内存耗尽。
方案3:线程池(多个数据一个任务)
简介
上边每个数据一个任务是不合适的,本处进行优化:多个数据一个任务。
controller
package com.example.order.controller;
import com.example.order.entity.OrderVO;
import com.example.order.entity.User;
import com.example.order.task.OrderTask;
import com.example.utils.SynchroniseUtil;
import com.example.utils.ThreadPoolExecutors;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.annotation.PostConstruct;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class OrderController {
private List<OrderVO> orderVOS = new ArrayList<>();
private List<User> users = new ArrayList<>();
//初始化时就创建好数据。模拟数据库已经存在的数据
public void createData() {
long dataCount = 500;
// 创建订单数据。模拟已经插入到数据库的订单
for (long i = 0; i < dataCount; i++) {
OrderVO orderVO = new OrderVO();
orderVO.setId(i + 1);
orderVO.setUserId(i + 1);
//防止电脑太快,导致都是同一个时间,所以加一个数
orderVO.setCreateTime(LocalDateTime.now().plusSeconds(i));
orderVOS.add(orderVO);
}
// 创建用户数据。模拟已经插入到数据库的用户
for (long i = 0; i < dataCount; i++) {
User user = new User();
user.setId(i + 1);
user.setUserName("用户名" + (i + 1));
users.add(user);
}
orderVOS = orderVOS.stream()
.sorted(Comparator.comparing(OrderVO::getCreateTime).reversed())
.collect(Collectors.toList());
}
("/getOrderDetails")
public List<OrderVO> getOrderDetails() throws Exception{
long startTime = System.currentTimeMillis();
List<OrderVO> orderVOList;
orderVOList = multiThread(orderVOS);
long endTime = System.currentTimeMillis();
System.out.println("执行时间:" + (endTime - startTime) + " ms");
return orderVOList;
}
private List<OrderVO> multiThread(List<OrderVO> orders) throws Exception{
ThreadPoolExecutor executor = ThreadPoolExecutors.getSingletonExecutor();
int unitLength = orders.size() / ThreadPoolExecutors.getQueueSize() + 1;
int synchroniseCount = orders.size() / unitLength;
synchroniseCount = orders.size() % unitLength == 0
? synchroniseCount : synchroniseCount + 1;
SynchroniseUtil<OrderVO> synchroniseUtil = new SynchroniseUtil<>(synchroniseCount);
System.out.println("任务个数:" + synchroniseCount);
for (int i = 0; i < orders.size(); i += unitLength) {
int toIndex = Math.min(i + unitLength, orders.size() - 1);
List<OrderVO> orderVOSubList = orders.subList(i, toIndex);
OrderTask orderTask = new OrderTask(orderVOSubList, users, synchroniseUtil);
executor.execute(orderTask);
}
List<OrderVO> list = null;
try {
list = synchroniseUtil.get(10, TimeUnit.SECONDS);
} catch (InterruptedException e) {
e.printStackTrace();
return null;
}
if (list != null) {
list = list.stream()
.sorted(Comparator.comparing(OrderVO::getCreateTime).reversed())
.collect(Collectors.toList());
}
return list;
}
}
自定义Task
package com.example.order.task;
import com.example.order.entity.OrderVO;
import com.example.order.entity.User;
import com.example.utils.SynchroniseUtil;
import java.util.List;
public class OrderTask implements Runnable {
private List<OrderVO> orderVOS;
private List<User> users;
private SynchroniseUtil<OrderVO> synchroniseUtil;
public OrderTask(List<OrderVO> orderVOS,
List<User> users,
SynchroniseUtil<OrderVO> synchroniseUtil) {
this.orderVOS = orderVOS;
this.users = users;
this.synchroniseUtil = synchroniseUtil;
}
public void run() {
//模拟从数据库里查数据
try {
Thread.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
for (OrderVO orderVO : orderVOS) {
for (User user : users) {
if (orderVO.getUserId().equals(user.getId())) {
orderVO.setUserName(user.getUserName());
break;
}
}
}
synchroniseUtil.addResult(orderVOS);
}
}
单例的线程池(指定队列长度)
package com.example.utils;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
public class ThreadPoolExecutors {
private static final int processorNumber =
Runtime.getRuntime().availableProcessors();
private static final int corePoolSize = processorNumber;
private static final int maximumPoolSize = processorNumber * 2 + 1;
private static final int queueSize = 100;
private static class ThreadPoolExecutorsHolder {
private static final ThreadPoolExecutor INSTANCE =
new ThreadPoolExecutor(corePoolSize, maximumPoolSize,
200,TimeUnit.MILLISECONDS,
new LinkedBlockingDeque<>(queueSize));
}
private ThreadPoolExecutors() {
}
public static ThreadPoolExecutor getSingletonExecutor() {
System.out.println("处理器数量:" + processorNumber);
return ThreadPoolExecutorsHolder.INSTANCE;
}
public static int getQueueSize() {
return queueSize;
}
}
封装CountDownLatch
package com.example.utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
public class SynchroniseUtil<T>{
private CountDownLatch countDownLatch;
private final List<T> result = Collections.synchronizedList(new ArrayList<>());
public SynchroniseUtil(int count) {
this.countDownLatch = new CountDownLatch(count);
}
public List<T> get() throws InterruptedException{
countDownLatch.await();
return this.result;
}
public List<T> get(long timeout, TimeUnit timeUnit) throws Exception{
if (countDownLatch.await(timeout, timeUnit)) {
return this.result;
} else {
throw new RuntimeException("超时");
}
}
public void addResult(T resultMember) {
result.add(resultMember);
countDownLatch.countDown();
}
public void addResult(List<T> resultMembers) {
result.addAll(resultMembers);
countDownLatch.countDown();
}
}
测试
访问:http://localhost:8080/getOrderDetails
后端结果
处理器数量:12
任务个数:84
执行时间:117 ms
前端结果
总结
优点
可见,此时速度比每个数据一个任务更快。(原因待分析,猜测:任务越少,在某个调度、唤醒之类的地方耗时就少,于是速度更快)
大量测试
数据量 | 单线程 | 线程池(每个数据一个任务) | 线程池(多个数据一个任务) |
100 | 1498 ms | 72 ms | 77 ms |
500 | 7525 ms | 312 ms | 113 ms |
1000 | 15024 ms | 605 ms | 125 ms |
5000 | 75160 ms | 3008 ms | 163 ms |
总结
可见,多个数据一个任务速度最快。
多个数据一个任务时,随着数据成倍的增加,耗时却没有成倍增加。原因分析:与线程池一个任务包含的数据量有关系。因为我是固定死了队列的长度,然后把总数据量平均分配到每一个队列上,如果数据量成倍增加,平均到一个任务里边,就增加的很少了。
当然,实际上任务并不是平均分到了队列里边,因为任务进来,先去占用核心线程,再去占用队列,再去占用最大线程数,见:(搜索:“线程池流程”)。按我本篇程序里的写法,实际队列并不会占满,而且最大线程数也没有用完。