当前位置: 首页>后端>正文

多线程转批量调用代码优化(二)

继上篇之后差不多两年了,最近项目又重构了波代码,重新回看之前的代码,虽然逻辑能看懂,但是大量的map/multimap还是影响了一些代码可读性,同时抽象和封装上还可以进一步抽取,又重新写了一版,功能基本类似。重新梳理了下逻辑后,觉得原来的标题《算法调用优化演变》有点不适合,重新取了个标题名。

多线程转批量

多线程转批量应该算是一个基础需求,目前项目除了nlu算法调用需要用到多线程转批量外,有一些后台任务查询clickhouse也需要用到(clickhouse不像mysql,多线程性能差,需要批量查询/插入)。同时如果抽象成一个工具类的话,consumer消费应该由实际调用者决定是单线程还是多线程消费的,只要提供一个线程安全的consume方法就可以了。

/**
 * 多线程转批量工具。可自定义每个request大小,可以实现CustomSize接口。
 *
 * @param <T> 请求类型
 * @param <R> 结果类型
 * @see CustomSize
 */
@Slf4j
@FieldDefaults(makeFinal = true)
public class ConcurrencyToBatch<T, R> {

    BlockingQueue<RequestWrapper> requestWaitQueue = new LinkedBlockingQueue<>();

    ExecutorService producerThreadPool;

    Storage storage;

    /**
     * @param batchSize       批量大小
     * @param batchLinger     批量收集时间,单位:毫秒
     * @param storageSize     storage大小。建议等于消费线程数+1。
     * @param mergeThreadName 生产者线程名
     */
    public ConcurrencyToBatch(int batchSize, long batchLinger, int storageSize, String mergeThreadName) {
        this.producerThreadPool = Executors.newSingleThreadExecutor(new CustomizableThreadFactory(mergeThreadName));

        this.storage = new Storage(storageSize);
        BatchController batchController = new BatchController(batchLinger, batchSize);
        Producer producer = new Producer(batchController);
        this.producerThreadPool.execute(producer);
    }

    /**
     * 提交一个请求,返回CompletableFuture。consume批处理完成后,future可获取到结果
     */
    public CompletableFuture<R> commit(T request) {
        CompletableFuture<R> future = new CompletableFuture<>();
        RequestWrapper wrapper = new RequestWrapper(request, future);
        while (!requestWaitQueue.offer(wrapper)) {
            // 无界队列,理论上不会进这里
            log.warn("队列已满,等待100ms");
            ThreadUtils.sleep(100);
        }
        return future;
    }

    /**
     * 消费一批数据。传入回调函数,需要将关联的future置为complete
     *
     * @param consumer 回调函数。注意返回的List<R>顺序需要与List<T>保持一对一关系
     */
    public void consume(Function<List<T>, List<R>> consumer) throws InterruptedException {
        Product product = storage.pop();
        List<RequestWrapper> batch = product.getBatch();
        List<T> requests = StreamUtils.map(batch, RequestWrapper::getRequest);
        List<R> results = null;
        try {
            results = consumer.apply(requests);
        } catch (Throwable ignore) {}

        if (results == null || batch.size() != results.size()) {
            log.warn("响应数据缺失");
            batch.forEach(i -> i.getFuture().complete(null));
        } else {
            for (int i = 0; i < batch.size(); i++) {
                batch.get(i).getFuture().complete(results.get(i));
            }
        }

    }

    public void destroy() {
        producerThreadPool.shutdown();
    }

    public interface CustomSize {
        int size();
    }

    @RequiredArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Producer implements Runnable {

        BatchController batchController;

        @Override
        @SuppressWarnings("InfiniteLoopStatement")
        public void run() {
            while (true) {
                try {
                    RequestWrapper firstRequest = requestWaitQueue.take();
                    this.batchController.init();
                    this.batchController.add(firstRequest);
                    while (this.batchController.sizeInLimit()) {
                        long timeout = this.batchController.getMaxWait() - System.currentTimeMillis();
                        RequestWrapper request = requestWaitQueue.poll(timeout, TimeUnit.MILLISECONDS);
                        if (request != null) {
                            this.batchController.add(request);
                        } else {
                            break;
                        }
                    }
                    Product product = new Product(this.batchController.getBatch());
                    if (log.isDebugEnabled()) {
                        int size = product.getBatch().size();
                        log.debug("时间到或满一批,发送到storage, size:{}", size);
                    }
                    storage.push(product);
                } catch (Throwable t) {
                    log.error("生产者处理异常", t);
                }
            }
        }

    }

    private class BatchController {

        @Getter
        List<RequestWrapper> batch;

        int currSize;

        @Getter
        long maxWait;

        final long batchLinger;

        final int batchSize;

        public BatchController(long batchLinger, int batchSize) {
            this.batch = new ArrayList<>();
            this.currSize = 0;

            this.batchLinger = batchLinger;
            this.batchSize = batchSize;
        }

        public void init() {
            this.batch = new ArrayList<>();
            this.currSize = 0;
            this.maxWait = System.currentTimeMillis() + batchLinger;
        }

        public void add(RequestWrapper request) {
            this.batch.add(request);
            T sourceRequest = request.getRequest();
            if (sourceRequest instanceof CustomSize) {
                this.currSize += ((CustomSize) sourceRequest).size();
            } else {
                this.currSize += 1;
            }
        }

        public boolean sizeInLimit() {
            return this.currSize < batchSize;
        }
    }

    @FieldDefaults(makeFinal = true)
    private class Storage {
        BlockingQueue<Product> queues;

        public Storage(int storageSize) {
            this.queues = new LinkedBlockingQueue<>(storageSize);
        }

        public void push(Product p) throws InterruptedException {
            queues.put(p);
        }

        public Product pop() throws InterruptedException {
            return queues.take();
        }
    }

    @Getter
    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class Product {
        List<RequestWrapper> batch;
    }

    @Getter
    @AllArgsConstructor
    @FieldDefaults(makeFinal = true)
    private class RequestWrapper {
        T request;

        CompletableFuture<R> future;
    }
}

与以前代码相比,使用CompletableFuture,干掉了Map<F, J> processResultsMap。用RequestWrapper封装request和future的关联关系,干掉了ConcurrentHashMap<F, FutureTask<J>> requestFutureMap。同时把fork/join的逻辑分割开,使得多线程转批量可以单独使用。单独提供CustomSize接口,有需要的情况下再定制request的size。

多线程消费封装

由于项目里实际消费都是采用多线程的,因此还是封装了一个多线程消费client,继承一下,减少重复代码。

/**
 * Wrapper around {@link ConcurrencyToBatch} that additionally creates a fixed pool of consumer
 * threads and runs the consume loop on each of them. Subclasses only implement
 * {@link #internalConsume(List)}.
 *
 * @param <T> request type
 * @param <R> result type
 * @see ConcurrencyToBatch
 */
@Slf4j
@FieldDefaults(makeFinal = true)
public abstract class AbstractConcurrencyToBatchClient<T, R> implements DisposableBean {

    ConcurrencyToBatch<T, R> concurrencyToBatch;

    ExecutorService consumerThreadPool;

    /**
     * Creates the batching tool and starts {@code consumerPoolSize} consumer threads.
     *
     * <p>NOTE(review): the consumer threads are started from this constructor and will
     * eventually call the abstract {@link #internalConsume(List)}. They block until the first
     * batch is available, so this presumably only matters if requests are committed before the
     * subclass constructor has finished — confirm against subclasses.
     *
     * @param batchSize          batch size limit
     * @param batchLinger        maximum time to keep collecting a batch, in milliseconds
     * @param mergeThreadName    name (prefix) of the producer thread
     * @param consumerPoolSize   number of consumer threads; the storage size is this value + 1
     * @param consumerThreadName name (prefix) of the consumer threads
     */
    public AbstractConcurrencyToBatchClient(int batchSize, long batchLinger, String mergeThreadName,
                                            int consumerPoolSize, String consumerThreadName) {
        this.concurrencyToBatch = new ConcurrencyToBatch<>(batchSize, batchLinger, consumerPoolSize + 1, mergeThreadName);

        this.consumerThreadPool = Executors.newFixedThreadPool(consumerPoolSize, new CustomizableThreadFactory(consumerThreadName));
        for (int i = 0; i < consumerPoolSize; i++) {
            this.consumerThreadPool.execute(this::consume);
        }
    }

    /**
     * Stops the batching tool and the consumer threads.
     *
     * <p>{@code shutdownNow()} is used because the consumer threads are normally blocked
     * waiting for a batch; a plain {@code shutdown()} does not interrupt running tasks, so
     * those threads would never terminate. A batch being processed at that moment receives
     * an interrupt as well.
     */
    @Override
    public void destroy() {
        this.concurrencyToBatch.destroy();
        this.consumerThreadPool.shutdownNow();
    }

    /**
     * Submits one request for batched processing.
     *
     * @param request the request
     * @return a future completed once the batch containing the request has been consumed
     */
    protected CompletableFuture<R> commit(T request) {
        return this.concurrencyToBatch.commit(request);
    }

    /** Consume loop run by every consumer thread until the pool is shut down. */
    private void consume() {
        while (!this.consumerThreadPool.isShutdown()) {
            try {
                log.debug("等待批处理数据");
                this.concurrencyToBatch.consume(this::internalConsume);
                log.debug("批处理完成");
            } catch (InterruptedException e) {
                // Interrupted by destroy(): restore the flag and let the thread end.
                Thread.currentThread().interrupt();
                return;
            } catch (Throwable t) {
                log.error("消费异常", t);
            }
        }
    }

    /**
     * Processes one batch.
     *
     * @param requests the requests of the batch
     * @return the results, one per request and in the same order as {@code requests}
     */
    protected abstract List<R> internalConsume(List<T> requests);

}

fork/join的支持

fork/join的接口设计保持不变。由于使用CompletableFuture替代了FutureTask,CompletableFuture的api可以完美干掉剩下的map/multimap (Multimap<F, F> forkRequestMultimap,Map<F, AtomicInteger> forkRequestUndoMap,Map<F, F> forkedRequestToSourceRequestMap),而且代码也更加简洁。差异之处在于以前其实是保证了fork出来的子request在等待队列里是连续的,批处理后会第一时间返回,但是新代码是fork完后一个个调用commit,多线程的情况下并不保证是连续的。不过实际使用差异几乎没有。

/**
 * Extends {@link AbstractConcurrencyToBatchClient} with fork/join support.
 *
 * <p>Some requests carry so much content that processing them in one piece times out. Such a
 * request is forked into several sub-requests that are committed individually, and the
 * sub-results are joined back into a single result.
 *
 * @param <F> forkable request type
 * @param <J> joinable result type
 * @see AbstractConcurrencyToBatchClient
 * @see Forkable
 * @see Joinable
 */
@Slf4j
@FieldDefaults(makeFinal = true)
public abstract class AbstractConcurrencyToBatchForkJoinClient<F extends Forkable<F>, J extends Joinable<J>>
        extends AbstractConcurrencyToBatchClient<F, J> implements DisposableBean {

    /** Capacity handed to {@code shouldFork} / {@code fork} when splitting a request. */
    int partitionCapacity;

    public AbstractConcurrencyToBatchForkJoinClient(int batchSize, long batchLinger, String mergeThreadName,
                                                    int consumerPoolSize, String consumerThreadName,
                                                    int partitionCapacity) {
        super(batchSize, batchLinger, mergeThreadName, consumerPoolSize, consumerThreadName);
        this.partitionCapacity = partitionCapacity;
    }

    /**
     * Commits the request as-is when it does not need forking; otherwise commits every fork
     * and returns a future carrying the joined result.
     */
    @Override
    protected CompletableFuture<J> commit(F request) {
        if (request.shouldFork(this.partitionCapacity)) {
            return commitForked(request);
        }
        return super.commit(request);
    }

    /**
     * Forks the request, commits the pieces in order and chains their futures: the first
     * result is resumed at its offset, each following result is resumed and joined onto the
     * accumulated one.
     */
    private CompletableFuture<J> commitForked(F request) {
        List<F> pieces = request.fork(this.partitionCapacity);
        Preconditions.checkArgument(!pieces.isEmpty());

        F head = pieces.get(0);
        CompletableFuture<J> joined = super.commit(head)
                .thenApply(headResult -> headResult.resume(head.getOffset()));

        for (int idx = 1; idx < pieces.size(); idx++) {
            F piece = pieces.get(idx);
            CompletableFuture<J> pieceFuture = super.commit(piece);
            joined = joined.thenCompose(accumulated ->
                    pieceFuture.thenApply(pieceResult -> accumulated.resumeThenJoin(pieceResult, piece.getOffset())));
        }

        Preconditions.checkNotNull(joined);
        return joined;
    }

}


https://www.xamrdz.com/backend/3z91940683.html

相关文章: