From c188304efa232751679789b68d06ba1390c13d34 Mon Sep 17 00:00:00 2001
From: Wang Zhiyang <1208931582@qq.com>
Date: Thu, 18 Jan 2024 03:49:51 +0800
Subject: [PATCH] GH-2968: Fix DEH#handleBatchAndReturnRemaining

DefaultErrorHandler#handleBatchAndReturnRemaining performed an invalid
recovery and could loop infinitely when the Kafka listener threw
BatchListenerFailedException and the failed record was the first one in
the remaining list.

* address empty catch
* add unit test

Co-authored-by: Zhiyang.Wang1

---
 .../kafka/listener/FailedBatchProcessor.java  |  71 +++----
 ...ErrorHandlerNoSeeksBatchListenerTests.java | 180 +++++++++++++-----
 2 files changed, 167 insertions(+), 84 deletions(-)
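For illustration only (not applied by this patch; the class, topic, and bean names are hypothetical): a minimal setup that hits GH-2968. With a batch listener and seeks disabled, throwing BatchListenerFailedException for the first record made the old seekOrRecover return the whole batch unchanged, so the failure tracker never ran and the same records were redelivered indefinitely once the back-off was exhausted.

```java
import java.util.List;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.kafka.annotation.KafkaListener;
import org.springframework.kafka.config.ConcurrentKafkaListenerContainerFactory;
import org.springframework.kafka.core.ConsumerFactory;
import org.springframework.kafka.listener.BatchListenerFailedException;
import org.springframework.kafka.listener.DefaultErrorHandler;
import org.springframework.util.backoff.FixedBackOff;

@Configuration
public class Gh2968Repro {

	@KafkaListener(id = "repro", topics = "repro")
	public void onBatch(List<String> batch) {
		// index 0: the failed record is the first of the remaining list
		throw new BatchListenerFailedException("rejected", 0);
	}

	@Bean
	ConcurrentKafkaListenerContainerFactory<String, String> factory(ConsumerFactory<String, String> cf) {
		ConcurrentKafkaListenerContainerFactory<String, String> factory =
				new ConcurrentKafkaListenerContainerFactory<>();
		factory.setConsumerFactory(cf);
		factory.setBatchListener(true);
		DefaultErrorHandler eh = new DefaultErrorHandler(new FixedBackOff(0L, 2L));
		eh.setSeekAfterError(false); // the code path this patch changes
		factory.setCommonErrorHandler(eh);
		return factory;
	}
}
```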
diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/FailedBatchProcessor.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/FailedBatchProcessor.java
index 58f3d48d6d..3ee7da9f90 100644
--- a/spring-kafka/src/main/java/org/springframework/kafka/listener/FailedBatchProcessor.java
+++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/FailedBatchProcessor.java
@@ -16,12 +16,10 @@
 
 package org.springframework.kafka.listener;
 
-import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -37,6 +35,7 @@
 
 import org.springframework.kafka.KafkaException;
 import org.springframework.kafka.KafkaException.Level;
+import org.springframework.kafka.support.KafkaUtils;
 import org.springframework.lang.Nullable;
 import org.springframework.util.backoff.BackOff;
 
@@ -50,6 +49,7 @@
  *
  * @author Gary Russell
  * @author Francois Rosiere
+ * @author Wang Zhiyang
  * @since 2.8
  *
  */
@@ -120,7 +120,7 @@ public void setReclassifyOnExceptionChange(boolean reclassifyOnExceptionChange)
 
 	@Override
 	protected void notRetryable(Stream<Class<? extends Exception>> notRetryable) {
 		if (this.fallbackBatchHandler instanceof ExceptionClassifier handler) {
-			notRetryable.forEach(ex -> handler.addNotRetryableExceptions(ex));
+			notRetryable.forEach(handler::addNotRetryableExceptions);
 		}
 	}
@@ -178,7 +178,6 @@ protected <K, V> ConsumerRecords<K, V> handle(Exception thrownException, Consume
 			else {
 				return String.format("Record not found in batch, index %d out of bounds (0, %d); "
 						+ "re-seeking batch", index, data.count() - 1);
-
 			}
 		});
 		fallback(thrownException, data, consumer, container, invokeListener);
@@ -201,11 +200,9 @@ private int findIndex(ConsumerRecords<?, ?> data, ConsumerRecord<?, ?> record) {
 			return -1;
 		}
 		int i = 0;
-		Iterator<?> iterator = data.iterator();
-		while (iterator.hasNext()) {
-			ConsumerRecord<?, ?> candidate = (ConsumerRecord<?, ?>) iterator.next();
-			if (candidate.topic().equals(record.topic()) && candidate.partition() == record.partition()
-					&& candidate.offset() == record.offset()) {
+		for (ConsumerRecord<?, ?> datum : data) {
+			if (datum.topic().equals(record.topic()) && datum.partition() == record.partition()
+					&& datum.offset() == record.offset()) {
 				break;
 			}
 			i++;
@@ -220,29 +217,25 @@ private <K, V> ConsumerRecords<K, V> seekOrRecover(Exception thrownException, @N
 		if (data == null) {
 			return ConsumerRecords.empty();
 		}
-		Iterator<?> iterator = data.iterator();
-		List<ConsumerRecord<?, ?>> toCommit = new ArrayList<>();
 		List<ConsumerRecord<?, ?>> remaining = new ArrayList<>();
 		int index = indexArg;
-		while (iterator.hasNext()) {
-			ConsumerRecord<?, ?> record = (ConsumerRecord<?, ?>) iterator.next();
+		Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+		for (ConsumerRecord<?, ?> datum : data) {
 			if (index-- > 0) {
-				toCommit.add(record);
+				offsets.compute(new TopicPartition(datum.topic(), datum.partition()),
+						(key, val) -> ListenerUtils.createOffsetAndMetadata(container, datum.offset() + 1));
 			}
 			else {
-				remaining.add(record);
+				remaining.add(datum);
 			}
 		}
-		Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
-		toCommit.forEach(rec -> offsets.compute(new TopicPartition(rec.topic(), rec.partition()),
-				(key, val) -> ListenerUtils.createOffsetAndMetadata(container, rec.offset() + 1)));
 		if (offsets.size() > 0) {
 			commit(consumer, container, offsets);
 		}
 		if (isSeekAfterError()) {
 			if (remaining.size() > 0) {
 				SeekUtils.seekOrRecover(thrownException, remaining, consumer, container, false,
-						getFailureTracker()::recovered, this.logger, getLogLevel());
+						getFailureTracker(), this.logger, getLogLevel());
 				ConsumerRecord<?, ?> recovered = remaining.get(0);
 				commit(consumer, container,
 						Collections.singletonMap(new TopicPartition(recovered.topic(), recovered.partition()),
@@ -254,10 +247,7 @@ private <K, V> ConsumerRecords<K, V> seekOrRecover(Exception thrownException, @N
 			return ConsumerRecords.empty();
 		}
 		else {
-			if (indexArg == 0) {
-				return (ConsumerRecords<K, V>) data; // first record just rerun the whole thing
-			}
-			else {
+			if (remaining.size() > 0) {
 				try {
 					if (getFailureTracker().recovered(remaining.get(0), thrownException, container,
 							consumer)) {
@@ -265,24 +255,35 @@ private <K, V> ConsumerRecords<K, V> seekOrRecover(Exception thrownException, @N
 					}
 				}
 				catch (Exception e) {
+					if (SeekUtils.isBackoffException(thrownException)) {
+						this.logger.debug(e, () -> KafkaUtils.format(remaining.get(0))
+								+ " included in remaining due to retry back off " + thrownException);
+					}
+					else {
+						this.logger.error(e, KafkaUtils.format(remaining.get(0))
+								+ " included in remaining due to " + thrownException);
+					}
 				}
-				Map<TopicPartition, List<ConsumerRecord<K, V>>> remains = new HashMap<>();
-				remaining.forEach(rec -> remains.computeIfAbsent(new TopicPartition(rec.topic(), rec.partition()),
-						tp -> new ArrayList<ConsumerRecord<K, V>>()).add((ConsumerRecord<K, V>) rec));
-				return new ConsumerRecords<>(remains);
 			}
+			if (remaining.isEmpty()) {
+				return ConsumerRecords.empty();
+			}
+			Map<TopicPartition, List<ConsumerRecord<K, V>>> remains = new HashMap<>();
+			remaining.forEach(rec -> remains.computeIfAbsent(new TopicPartition(rec.topic(), rec.partition()),
+					tp -> new ArrayList<>()).add((ConsumerRecord<K, V>) rec));
+			return new ConsumerRecords<>(remains);
 		}
 	}
 
-	private void commit(Consumer<?, ?> consumer, MessageListenerContainer container, Map<TopicPartition, OffsetAndMetadata> offsets) {
+	private void commit(Consumer<?, ?> consumer, MessageListenerContainer container,
+			Map<TopicPartition, OffsetAndMetadata> offsets) {
 
-		boolean syncCommits = container.getContainerProperties().isSyncCommits();
-		Duration timeout = container.getContainerProperties().getSyncCommitTimeout();
-		if (syncCommits) {
-			consumer.commitSync(offsets, timeout);
+		ContainerProperties properties = container.getContainerProperties();
+		if (properties.isSyncCommits()) {
+			consumer.commitSync(offsets, properties.getSyncCommitTimeout());
 		}
 		else {
-			OffsetCommitCallback commitCallback = container.getContainerProperties().getCommitCallback();
+			OffsetCommitCallback commitCallback = properties.getCommitCallback();
 			if (commitCallback == null) {
 				commitCallback = LOGGING_COMMIT_CALLBACK;
 			}
@@ -304,8 +305,8 @@ private BatchListenerFailedException getBatchListenerFailedException(Throwable t
 			throwable = throwable.getCause();
 			checked.add(throwable);
 
-			if (throwable instanceof BatchListenerFailedException) {
-				target = (BatchListenerFailedException) throwable;
+			if (throwable instanceof BatchListenerFailedException batchListenerFailedException) {
+				target = batchListenerFailedException;
 				break;
 			}
 		}
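The tail of the reworked seekOrRecover is the heart of the fix: the first remaining record is offered to the failure tracker, dropped once recovered, and whatever is left is regrouped by partition instead of rerunning the whole batch. A standalone sketch of that regrouping using only kafka-clients types (the class and method names are illustrative, not from the patch):

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.common.TopicPartition;

public class RemainderSketch {

	static ConsumerRecords<String, String> remainder(List<ConsumerRecord<String, String>> remaining) {
		if (remaining.isEmpty()) {
			return ConsumerRecords.empty(); // nothing left: no redelivery, loop broken
		}
		// regroup the leftover records by their topic-partition, preserving order
		Map<TopicPartition, List<ConsumerRecord<String, String>>> remains = new HashMap<>();
		remaining.forEach(rec -> remains.computeIfAbsent(
				new TopicPartition(rec.topic(), rec.partition()), tp -> new ArrayList<>()).add(rec));
		return new ConsumerRecords<>(remains);
	}
}
```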
diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/DefaultErrorHandlerNoSeeksBatchListenerTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/DefaultErrorHandlerNoSeeksBatchListenerTests.java
index 20f29e969b..0af228ae86 100644
--- a/spring-kafka/src/test/java/org/springframework/kafka/listener/DefaultErrorHandlerNoSeeksBatchListenerTests.java
+++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/DefaultErrorHandlerNoSeeksBatchListenerTests.java
@@ -19,7 +19,6 @@
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyMap;
-import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willAnswer;
 import static org.mockito.BDDMockito.willReturn;
@@ -46,7 +45,6 @@
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
-import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.record.TimestampType;
@@ -61,30 +59,28 @@
 import org.springframework.kafka.config.ConcurrentKafkaListenerContainerFactory;
 import org.springframework.kafka.config.KafkaListenerEndpointRegistry;
 import org.springframework.kafka.core.ConsumerFactory;
-import org.springframework.kafka.core.ProducerFactory;
 import org.springframework.kafka.test.utils.KafkaTestUtils;
 import org.springframework.test.annotation.DirtiesContext;
 import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
+import org.springframework.util.backoff.FixedBackOff;
 
 /**
  * @author Gary Russell
+ * @author Wang Zhiyang
  * @since 2.9
 *
 */
 @SpringJUnitConfig
 @DirtiesContext
-@SuppressWarnings("deprecation")
 public class DefaultErrorHandlerNoSeeksBatchListenerTests {
 
 	private static final String CONTAINER_ID = "container";
 
-	@SuppressWarnings("rawtypes")
-	@Autowired
-	private Consumer consumer;
+	private static final String CONTAINER_ID_2 = "container2";
 
 	@SuppressWarnings("rawtypes")
 	@Autowired
-	private Producer producer;
+	private Consumer consumer;
 
 	@Autowired
 	private Config config;
@@ -104,7 +100,7 @@ void retriesWithNoSeeksBatchListener() throws Exception {
 		assertThat(this.config.commitLatch.await(10, TimeUnit.SECONDS)).isTrue();
 		this.registry.stop();
 		assertThat(this.config.closeLatch.await(10, TimeUnit.SECONDS)).isTrue();
-		InOrder inOrder = inOrder(this.consumer, this.producer);
+		InOrder inOrder = inOrder(this.consumer);
 		inOrder.verify(this.consumer).subscribe(any(Collection.class), any(ConsumerRebalanceListener.class));
 		inOrder.verify(this.consumer).poll(Duration.ofMillis(ContainerProperties.DEFAULT_POLL_TIMEOUT));
 		Map<TopicPartition, OffsetAndMetadata> offsets = new LinkedHashMap<>();
@@ -123,25 +119,45 @@
 		assertThat(this.config.contents).contains("foo", "bar", "baz", "qux", "qux", "qux", "fiz", "buz");
 	}
 
+	/*
+	 * Deliver 6 records from three partitions, fail on the last record
+	 */
+	@Test
+	void retriesWithNoSeeksAndBatchListener2() throws Exception {
+		assertThat(this.config.pollLatch2.await(10, TimeUnit.SECONDS)).isTrue();
+		assertThat(this.config.fooLatch2.await(10, TimeUnit.SECONDS)).isTrue();
+		assertThat(this.config.deliveryCount.get()).isEqualTo(4);
+		assertThat(this.config.ehException2).isInstanceOf(ListenerExecutionFailedException.class);
+		assertThat(((ListenerExecutionFailedException) this.config.ehException2).getGroupId()).isEqualTo(CONTAINER_ID_2);
+	}
+
 	@Configuration
 	@EnableKafka
 	public static class Config {
 
 		final CountDownLatch pollLatch = new CountDownLatch(1);
 
+		final CountDownLatch pollLatch2 = new CountDownLatch(1);
+
 		final CountDownLatch deliveryLatch = new CountDownLatch(2);
 
 		final CountDownLatch closeLatch = new CountDownLatch(1);
 
 		final CountDownLatch commitLatch = new CountDownLatch(2);
 
+		final AtomicInteger deliveryCount = new AtomicInteger(0);
+
+		final CountDownLatch fooLatch2 = new CountDownLatch(1);
+
 		final AtomicBoolean fail = new AtomicBoolean(true);
 
 		final List<String> contents = new ArrayList<>();
 
 		volatile Exception ehException;
 
-		@KafkaListener(id = CONTAINER_ID, topics = "foo")
+		volatile Exception ehException2;
+
+		@KafkaListener(id = CONTAINER_ID, topics = "foo", containerFactory = "kafkaListenerContainerFactory")
 		public void foo(List<String> in) {
 			this.contents.addAll(in);
 			this.deliveryLatch.countDown();
@@ -150,6 +166,19 @@ public void foo(List<String> in) {
 			}
 		}
 
+		@KafkaListener(id = CONTAINER_ID_2, topics = "foo2", containerFactory = "kafkaListenerContainerFactory2")
+		public void foo2(List<String> in) {
+			deliveryCount.incrementAndGet();
+			int index = 0;
+			for (String str : in) {
+				if ("qux".equals(str)) {
+					throw new BatchListenerFailedException("test", index);
+				}
+				index++;
+			}
+			fooLatch2.countDown();
+		}
+
 		@SuppressWarnings({ "rawtypes" })
 		@Bean
 		public ConsumerFactory consumerFactory() {
@@ -164,30 +193,7 @@ public ConsumerFactory consumerFactory() {
 		@Bean
 		public Consumer consumer() {
 			final Consumer consumer = mock(Consumer.class);
-			final TopicPartition topicPartition0 = new TopicPartition("foo", 0);
-			final TopicPartition topicPartition1 = new TopicPartition("foo", 1);
-			final TopicPartition topicPartition2 = new TopicPartition("foo", 2);
-			willAnswer(i -> {
-				((ConsumerRebalanceListener) i.getArgument(1)).onPartitionsAssigned(
-						Collections.singletonList(topicPartition1));
-				return null;
-			}).given(consumer).subscribe(any(Collection.class), any(ConsumerRebalanceListener.class));
-			Map<TopicPartition, List<ConsumerRecord>> records1 = new LinkedHashMap<>();
-			records1.put(topicPartition0, Arrays.asList(
-					new ConsumerRecord("foo", 0, 0L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "foo",
-							new RecordHeaders(), Optional.empty()),
-					new ConsumerRecord("foo", 0, 1L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "bar",
-							new RecordHeaders(), Optional.empty())));
-			records1.put(topicPartition1, Arrays.asList(
-					new ConsumerRecord("foo", 1, 0L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "baz",
-							new RecordHeaders(), Optional.empty()),
-					new ConsumerRecord("foo", 1, 1L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "qux",
-							new RecordHeaders(), Optional.empty())));
-			records1.put(topicPartition2, Arrays.asList(
-					new ConsumerRecord("foo", 2, 0L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "fiz",
-							new RecordHeaders(), Optional.empty()),
-					new ConsumerRecord("foo", 2, 1L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "buz",
-							new RecordHeaders(), Optional.empty())));
+			Map<TopicPartition, List<ConsumerRecord>> records1 = createRecords(consumer, "foo");
 			final AtomicInteger which = new AtomicInteger();
 			willAnswer(i -> {
 				this.pollLatch.countDown();
@@ -218,9 +224,63 @@
 		@SuppressWarnings({ "rawtypes", "unchecked" })
 		@Bean
+		public Consumer consumer2() {
+			final Consumer consumer = mock(Consumer.class);
+			Map<TopicPartition, List<ConsumerRecord>> records1 = createRecords(consumer, "foo2");
+			final TopicPartition topicPartition0 = new TopicPartition("foo2", 0);
+			Map<TopicPartition, List<ConsumerRecord>> records2 = new LinkedHashMap<>();
+			records2.put(topicPartition0, List.of(
+					new ConsumerRecord("foo2", 1, 2L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "foo",
+							new RecordHeaders(), Optional.empty())));
+			final AtomicInteger which = new AtomicInteger();
+			willAnswer(i -> {
+				this.pollLatch2.countDown();
+				switch (which.getAndIncrement()) {
+					case 0:
+						return new ConsumerRecords(records1);
+					case 3: // after backoff
+						return new ConsumerRecords(records2);
+					default:
+						try {
+							Thread.sleep(0);
+						}
+						catch (InterruptedException e) {
+							Thread.currentThread().interrupt();
+						}
+						return new ConsumerRecords(Collections.emptyMap());
+				}
+			}).given(consumer).poll(any());
+			willReturn(new ConsumerGroupMetadata(CONTAINER_ID_2)).given(consumer).groupMetadata();
+			return consumer;
+		}
+
+		@SuppressWarnings({ "rawtypes" })
+		@Bean
+		public ConsumerFactory consumerFactory2() {
+			ConsumerFactory consumerFactory = mock(ConsumerFactory.class);
+			final Consumer consumer = consumer2();
+			given(consumerFactory.createConsumer(CONTAINER_ID_2, "", "-0", KafkaTestUtils.defaultPropertyOverrides()))
+					.willReturn(consumer);
+			return consumerFactory;
+		}
+
+		@SuppressWarnings({ "rawtypes" })
+		@Bean
 		public ConcurrentKafkaListenerContainerFactory kafkaListenerContainerFactory() {
+			return createConcurrentKafkaListenerContainerFactory(consumerFactory(), CONTAINER_ID);
+		}
+
+		@SuppressWarnings({ "rawtypes" })
+		@Bean
+		public ConcurrentKafkaListenerContainerFactory kafkaListenerContainerFactory2() {
+			return createConcurrentKafkaListenerContainerFactory(consumerFactory2(), CONTAINER_ID_2);
+		}
+
+		@SuppressWarnings({ "rawtypes", "unchecked" })
+		private ConcurrentKafkaListenerContainerFactory createConcurrentKafkaListenerContainerFactory(
+				ConsumerFactory consumerFactory, String id) {
+
 			ConcurrentKafkaListenerContainerFactory factory = new ConcurrentKafkaListenerContainerFactory();
-			factory.setConsumerFactory(consumerFactory());
+			factory.setConsumerFactory(consumerFactory);
 			factory.setBatchListener(true);
 			factory.getContainerProperties().setPollTimeoutWhilePaused(Duration.ZERO);
 			DefaultErrorHandler eh = new DefaultErrorHandler() {
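The mock's `case 3: // after backoff` and the test's deliveryCount assertion line up with the back-off configured in the next hunk, which continues inside this anonymous DefaultErrorHandler. Assuming FixedBackOff(0, 2) means one initial attempt plus two retries, the arithmetic works out as follows (a hypothetical back-of-envelope check, not test code):

```java
// Hypothetical check of the expected delivery count for the second container:
// three deliveries containing the failing "qux" record all throw; once "qux"
// is recovered, the remainder of the batch is delivered once more and succeeds.
public class DeliveryCountSketch {

	public static void main(String[] args) {
		int failedDeliveries = 1 + 2;  // initial attempt + 2 back-off retries
		int successfulDelivery = 1;    // remainder redelivered after recovery
		System.out.println(failedDeliveries + successfulDelivery); // 4, matching the assertion
	}
}
```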
@@ -230,29 +290,51 @@ public <K, V> ConsumerRecords<K, V> handleBatchAndReturnRemaining(Exception thro
 						ConsumerRecords<?, ?> data, Consumer<?, ?> consumer,
 						MessageListenerContainer container, Runnable invokeListener) {
 
-					Config.this.ehException = thrownException;
+					if (id.equals(CONTAINER_ID)) {
+						Config.this.ehException = thrownException;
+					}
+					else {
+						Config.this.ehException2 = thrownException;
+					}
 					return super.handleBatchAndReturnRemaining(thrownException, data, consumer, container,
 							invokeListener);
 				}
 
 			};
 			eh.setSeekAfterError(false);
+			if (id.equals(CONTAINER_ID_2)) {
+				eh.setBackOffFunction((rc, ex) -> new FixedBackOff(0, 2));
+			}
 			factory.setCommonErrorHandler(eh);
 			return factory;
 		}
 
-		@SuppressWarnings("rawtypes")
-		@Bean
-		public ProducerFactory producerFactory() {
-			ProducerFactory pf = mock(ProducerFactory.class);
-			given(pf.createProducer(isNull())).willReturn(producer());
-			given(pf.transactionCapable()).willReturn(true);
-			return pf;
-		}
-
-		@SuppressWarnings("rawtypes")
-		@Bean
-		public Producer producer() {
-			return mock(Producer.class);
+		@SuppressWarnings({ "rawtypes", "unchecked" })
+		private Map<TopicPartition, List<ConsumerRecord>> createRecords(Consumer consumer, String topic) {
+			final TopicPartition topicPartition0 = new TopicPartition(topic, 0);
+			final TopicPartition topicPartition1 = new TopicPartition(topic, 1);
+			final TopicPartition topicPartition2 = new TopicPartition(topic, 2);
+			willAnswer(i -> {
+				((ConsumerRebalanceListener) i.getArgument(1)).onPartitionsAssigned(
+						Collections.singletonList(topicPartition1));
+				return null;
+			}).given(consumer).subscribe(any(Collection.class), any(ConsumerRebalanceListener.class));
+			Map<TopicPartition, List<ConsumerRecord>> records1 = new LinkedHashMap<>();
+			records1.put(topicPartition0, Arrays.asList(
+					new ConsumerRecord(topic, 0, 0L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "foo",
+							new RecordHeaders(), Optional.empty()),
+					new ConsumerRecord(topic, 0, 1L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "bar",
+							new RecordHeaders(), Optional.empty())));
+			records1.put(topicPartition1, Arrays.asList(
+					new ConsumerRecord(topic, 1, 0L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "baz",
+							new RecordHeaders(), Optional.empty()),
+					new ConsumerRecord(topic, 1, 1L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "qux",
+							new RecordHeaders(), Optional.empty())));
+			records1.put(topicPartition2, Arrays.asList(
+					new ConsumerRecord(topic, 2, 0L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "fiz",
+							new RecordHeaders(), Optional.empty()),
+					new ConsumerRecord(topic, 2, 1L, 0L, TimestampType.NO_TIMESTAMP_TYPE, 0, 0, null, "buz",
+							new RecordHeaders(), Optional.empty())));
+			return records1;
 		}
 	}
 }
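For reference, a sketch of the error-handler wiring the new test exercises for the second container: no seeks after error, and a per-record BackOff chosen by a function. The class and method names are placeholders, and the FixedBackOff values mirror the test rather than any recommended production setting.

```java
import org.springframework.kafka.listener.DefaultErrorHandler;
import org.springframework.util.backoff.FixedBackOff;

public class NoSeeksErrorHandlerSketch {

	static DefaultErrorHandler noSeeksHandler() {
		DefaultErrorHandler eh = new DefaultErrorHandler();
		// keep the container from re-seeking; remaining records are returned in memory
		eh.setSeekAfterError(false);
		// 0 ms between attempts, at most 2 retries after the initial delivery
		eh.setBackOffFunction((record, ex) -> new FixedBackOff(0, 2));
		return eh;
	}
}
```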