0
点赞
收藏
分享

微信扫一扫

【JUC的LOCK框架系列三】工具类之CountDownLatch

代码小姐 2022-04-26 阅读 37
java

CountDownLatch

文章目录


一种并发流程控制的工具类。

主要成员

构造函数

接收参数count,代表可以有几个线程持有共享锁。不能小于0。

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

成员变量

/*同步队列*/
private final Sync sync;

内部类Sync

Sync继承了AQS,并实现了共享锁的获取与释放相关方法。同步控制依赖AQS。详细见AQS的分析.

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        // 构造器
        Sync(int count) {
            setState(count);
        }
        
        // 返回当前共享变量的值
        int getCount() {
            return getState();
        }

        // 尝试获取共享式锁
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        // 尝试释放共享式锁
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                // 没有锁需要释放
                if (c == 0)
                    return false;
                int nextc = c-1;
                // CAS操作
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

核心方法

方法名描述
void await() throws InterruptedException1.当前锁的个数为0,则立即返回
2.如果当前锁的个数大于0,当前线程会休眠直到发生以下两种情况之一:
2.1当前线程被其他线程中断
2.2锁的计数变为0
如果当前线程在进入此方法之前,线程中断标记为true,或者在等待时中断,则抛出{@link InterruptedException}并清除当前线程的中断状态。
boolean await(long timeout, TimeUnit unit) throws InterruptedException限时的等待锁,timeout为最大等待时长。unit时间单位
void countDown()减少锁存器的计数,如果计数达到零,则释放所有等待线程。
如果当前计数大于零,则递减。 如果新计数为零,则重新启用所有等待线程以进行线程调度。
如果当前计数为零,则不会发生任何事情。
long getCount()返回现在的锁计数

AQS共享状态state的变化

tryAcquireShared子类实现判断tryAcquireShared返回值tryAcquireShared返回值含义await流程
state==01获取共享锁成功,并且后续获取也可能获取成功返回,不阻塞
-0获取共享锁成功,但后续获取可能不会成功返回,不阻塞
state!=0-1获取共享锁失败阻塞

示例

使用场景一

第一个是启动信号,它阻止任何工作线程(Worker)继续执行,直到驱动程序(Driver)准备好让他们继续前进; 第二个是完成信号,允许驱动程序(Driver)等待所有工作线程(Worker)完成。

import java.util.concurrent.CountDownLatch;

public class Driver {
    
    static class Worker implements Runnable {

        private final CountDownLatch startSignal;
        private final CountDownLatch doneSignal;

        Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
            this.startSignal = startSignal;
            this.doneSignal = doneSignal;
        }

        @Override
        public void run() {
            try {
                // 由于开始信号未释放(startSignal对应的state值还为1),工作线程阻塞
                startSignal.await();
                doWork();
            } catch (InterruptedException ex) {
                System.out.println(Thread.currentThread().getName() + "is interrupted");
            } // return;
            finally {
                doneSignal.countDown();
            }
        }

        void doWork() throws InterruptedException {
            // 休眠0.5s 代表业务运行时间,
            Thread.sleep(500);
            System.out.println(Thread.currentThread().getName() + " startSignal.count:" + startSignal.getCount());
            System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
        }
    }

    public static void main(String[] args) throws InterruptedException {
        // 定义两个计数器 开始信号startSignal计数为1 结束信号doneSignal计数为2
        CountDownLatch startSignal = new CountDownLatch(1);
        CountDownLatch doneSignal = new CountDownLatch(5);

        // 创建5个线程并使其在就绪状态并开始执行,
        for (int i = 0; i < 5; ++i)
        {
            new Thread(new Worker(startSignal, doneSignal)).start();
        }
        //此时信号的值
        System.out.println(Thread.currentThread().getName() + " startSignal.count:" + startSignal.getCount());
        System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
        // 此方法执行之后,开始信号对应的状态值就不在为0,而为1。代表着工作线程可以执行了。结束信号对应的状态值还是5,但是工作线程开始运行后会递减。
        startSignal.countDown();
        System.out.println(Thread.currentThread().getName() + " startSignal.count:" + startSignal.getCount());
        System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
        // 此时结束信号对应的状态值不为0,主线程阻塞,等待所有工作线程执行完。
        doneSignal.await();
        // 最终信号的计数都为0。
        System.out.println(Thread.currentThread().getName() + " startSignal.count:" + startSignal.getCount());
        System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
    }
}

执行结果

main startSignal.count:1
main doneSignal.count:5
main startSignal.count:0
main doneSignal.count:5
Thread-0 startSignal.count:0
Thread-3 startSignal.count:0
Thread-2 startSignal.count:0
Thread-1 startSignal.count:0
Thread-4 startSignal.count:0
Thread-1 doneSignal.count:5
Thread-2 doneSignal.count:5
Thread-3 doneSignal.count:5
Thread-0 doneSignal.count:5
Thread-4 doneSignal.count:5
main startSignal.count:0
main doneSignal.count:0

使用场景二

另一个典型的用法是将一个问题分成 N 个部分,每个部分用一个 Runnable 进行描述,该 Runnable 执行该部分并在锁存器上倒计时,并将所有 Runnables 排队到一个 Executor。 当所有子部分都完成后,协调线程就可以通过await了。

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Driver2 {
    static class WorkerRunnable implements Runnable {
        private final CountDownLatch doneSignal;
        private final int i;

        WorkerRunnable(CountDownLatch doneSignal, int i) {
            this.doneSignal = doneSignal;
            this.i = i;
        }

        @Override
        public void run() {

            doWork(i);
            doneSignal.countDown();
            System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());

        }

        void doWork(int i) {
        }
    }

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch doneSignal = new CountDownLatch(5);
        ExecutorService e = Executors.newSingleThreadExecutor();
        System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
        for (int i = 0; i < 5; ++i) {
            e.execute(new WorkerRunnable(doneSignal, i));
        }
        System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
        doneSignal.await();
        System.out.println(Thread.currentThread().getName() + " doneSignal.count:" + doneSignal.getCount());
        e.shutdown();
    }
}

运行结果,第7行比第8行优先输出的原因是在最后一个工作线程调用countDown()后,主线程已经被唤醒立即执行了,而工作线程的后续输出可能就慢与主线程。所以常规情况下,工作线程中countDown()的调用需要注意位置。

main doneSignal.count:5
main doneSignal.count:5
pool-1-thread-1 doneSignal.count:4
pool-1-thread-1 doneSignal.count:3
pool-1-thread-1 doneSignal.count:2
pool-1-thread-1 doneSignal.count:1
pool-1-thread-1 doneSignal.count:0
main doneSignal.count:0

源码解析

public class CountDownLatch {
    /**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        // 构造器-设置同步状态值
        Sync(int count) {
            setState(count);
        }
		// 获取当前同步状态值
        int getCount() {
            return getState();
        }
        // 尝试获取锁,同步状态值为0返回1 代表获取到锁
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
		// 尝试释放锁
        protected boolean tryReleaseShared(int releases) {
            // 当锁的个数为0就以false退出,无需进行锁释放,如果释放锁之后,同步状态为0代表着锁完全释放,否则说明锁没有完全释放掉。
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    /**
     * 构造器
     */
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    /**
     * 响应中断式尝试获取锁
     */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * 包含等待超时时间,响应中断式尝试获取锁
     */
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    /**
     * 释放锁
     */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * 获取同步状态的值,当前可用锁的个数
     */
    public long getCount() {
        return sync.getCount();
    }

    /**
     */
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

举报

相关推荐

0 条评论