public class RocketMQSourceDemo <OUT> extends RichParallelSourceFunction<OUT>
implements CheckpointedFunction, CheckpointListener, ResultTypeQueryable<OUT> {
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(RocketMQSource.class);
private transient MQPullConsumerScheduleService pullConsumerScheduleService;
private DefaultMQPullConsumer consumer;
private KeyValueDeserializationSchema<OUT> schema;
private RunningChecker runningChecker;
private transient ListState<Tuple2<MessageQueue, Long>> unionOffsetStates;
private Map<MessageQueue, Long> offsetTable;
private Map<MessageQueue, Long> restoredOffsets;
private LinkedMap pendingOffsetsToCommit;
private Properties props;
private String topic;
private String group;
private static final String OFFSETS_STATE_NAME = "topic-partition-offset-states";
private transient volatile boolean restored;
private transient boolean enableCheckpoint;
public RocketMQSource(KeyValueDeserializationSchema<OUT> schema, Properties props) {
this.schema = schema;
this.props = props;
}
@Override
public void open(Configuration parameters) throws Exception {
LOG.debug("source open....");
Validate.notEmpty(props, "Consumer properties can not be empty");
Validate.notNull(schema, "KeyValueDeserializationSchema can not be null");
this.topic = props.getProperty(RocketMQConfig.CONSUMER_TOPIC);
this.group = props.getProperty(RocketMQConfig.CONSUMER_GROUP);
Validate.notEmpty(topic, "Consumer topic can not be empty");
Validate.notEmpty(group, "Consumer group can not be empty");
this.enableCheckpoint = ((StreamingRuntimeContext) getRuntimeContext()).isCheckpointingEnabled();
if (offsetTable == null) {
offsetTable = new ConcurrentHashMap<>();
}
if (restoredOffsets == null) {
restoredOffsets = new ConcurrentHashMap<>();
}
if (pendingOffsetsToCommit == null) {
pendingOffsetsToCommit = new LinkedMap();
}
runningChecker = new RunningChecker();
pullConsumerScheduleService = new MQPullConsumerScheduleService(group);
consumer = pullConsumerScheduleService.getDefaultMQPullConsumer();
consumer.setInstanceName(String.valueOf(getRuntimeContext().getIndexOfThisSubtask()) + "_" + UUID.randomUUID());
RocketMQConfig.buildConsumerConfigs(props, consumer);
}
@Override
public void run(SourceContext context) throws Exception {
LOG.debug("source run....");
final Object lock = context.getCheckpointLock();
int delayWhenMessageNotFound = getInteger(props, RocketMQConfig.CONSUMER_DELAY_WHEN_MESSAGE_NOT_FOUND,
RocketMQConfig.DEFAULT_CONSUMER_DELAY_WHEN_MESSAGE_NOT_FOUND);
String tag = props.getProperty(RocketMQConfig.CONSUMER_TAG, RocketMQConfig.DEFAULT_CONSUMER_TAG);
int pullPoolSize = getInteger(props, RocketMQConfig.CONSUMER_PULL_POOL_SIZE,
RocketMQConfig.DEFAULT_CONSUMER_PULL_POOL_SIZE);
int pullBatchSize = getInteger(props, RocketMQConfig.CONSUMER_BATCH_SIZE,
RocketMQConfig.DEFAULT_CONSUMER_BATCH_SIZE);
pullConsumerScheduleService.setPullThreadNums(pullPoolSize);
pullConsumerScheduleService.registerPullTaskCallback(topic, new PullTaskCallback() {
@Override
public void doPullTask(MessageQueue mq, PullTaskContext pullTaskContext) {
try {
long offset = getMessageQueueOffset(mq);
if (offset < 0) {
return;
}
PullResult pullResult = consumer.pull(mq, tag, offset, pullBatchSize);
boolean found = false;
switch (pullResult.getPullStatus()) {
case FOUND:
List<MessageExt> messages = pullResult.getMsgFoundList();
for (MessageExt msg : messages) {
byte[] key = msg.getKeys() != null ? msg.getKeys().getBytes(StandardCharsets.UTF_8) : null;
byte[] value = msg.getBody();
OUT data = schema.deserializeKeyAndValue(key, value);
synchronized (lock) {
context.collectWithTimestamp(data, msg.getBornTimestamp());
}
}
found = true;
break;
case NO_MATCHED_MSG:
LOG.debug("No matched message after offset {} for queue {}", offset, mq);
break;
case NO_NEW_MSG:
break;
case OFFSET_ILLEGAL:
LOG.warn("Offset {} is illegal for queue {}", offset, mq);
break;
default:
break;
}
synchronized (lock) {
putMessageQueueOffset(mq, pullResult.getNextBeginOffset());
}
if (found) {
pullTaskContext.setPullNextDelayTimeMillis(0);
} else {
pullTaskContext.setPullNextDelayTimeMillis(delayWhenMessageNotFound);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
try {
pullConsumerScheduleService.start();
} catch (MQClientException e) {
throw new RuntimeException(e);
}
runningChecker.setRunning(true);
awaitTermination();
}
private void awaitTermination() throws InterruptedException {
while (runningChecker.isRunning()) {
Thread.sleep(50);
}
}
private long getMessageQueueOffset(MessageQueue mq) throws MQClientException {
Long offset = offsetTable.get(mq);
if (restored && offset == null) {
offset = restoredOffsets.get(mq);
}
if (offset == null) {
offset = consumer.fetchConsumeOffset(mq, false);
if (offset < 0) {
String initialOffset = props.getProperty(RocketMQConfig.CONSUMER_OFFSET_RESET_TO, CONSUMER_OFFSET_LATEST);
switch (initialOffset) {
case CONSUMER_OFFSET_EARLIEST:
offset = consumer.minOffset(mq);
break;
case CONSUMER_OFFSET_LATEST:
offset = consumer.maxOffset(mq);
break;
case CONSUMER_OFFSET_TIMESTAMP:
offset = consumer.searchOffset(mq, getLong(props,
RocketMQConfig.CONSUMER_OFFSET_FROM_TIMESTAMP, System.currentTimeMillis()));
break;
default:
throw new IllegalArgumentException("Unknown value for CONSUMER_OFFSET_RESET_TO.");
}
}
}
offsetTable.put(mq, offset);
return offsetTable.get(mq);
}
private void putMessageQueueOffset(MessageQueue mq, long offset) throws MQClientException {
offsetTable.put(mq, offset);
if (!enableCheckpoint) {
consumer.updateConsumeOffset(mq, offset);
}
}
@Override
public void cancel() {
LOG.debug("cancel ...");
runningChecker.setRunning(false);
if (pullConsumerScheduleService != null) {
pullConsumerScheduleService.shutdown();
}
offsetTable.clear();
restoredOffsets.clear();
pendingOffsetsToCommit.clear();
}
@Override
public void close() throws Exception {
LOG.debug("close ...");
try {
cancel();
} finally {
super.close();
}
}
@Override
public void snapshotState(FunctionSnapshotContext context) throws Exception {
if (!runningChecker.isRunning()) {
LOG.debug("snapshotState() called on closed source; returning null.");
return;
}
if (LOG.isDebugEnabled()) {
LOG.debug("Snapshotting state {} ...", context.getCheckpointId());
}
unionOffsetStates.clear();
HashMap<MessageQueue, Long> currentOffsets = new HashMap<>(offsetTable.size());
Set<MessageQueue> assignedQueues = consumer.fetchMessageQueuesInBalance(topic);
offsetTable.entrySet().removeIf(item -> !assignedQueues.contains(item.getKey()));
for (Map.Entry<MessageQueue, Long> entry : offsetTable.entrySet()) {
unionOffsetStates.add(Tuple2.of(entry.getKey(), entry.getValue()));
currentOffsets.put(entry.getKey(), entry.getValue());
}
pendingOffsetsToCommit.put(context.getCheckpointId(), currentOffsets);
if (LOG.isDebugEnabled()) {
LOG.debug("Snapshotted state, last processed offsets: {}, checkpoint id: {}, timestamp: {}",
offsetTable, context.getCheckpointId(), context.getCheckpointTimestamp());
}
}
@Override
public void initializeState(FunctionInitializationContext context) throws Exception {
LOG.debug("initialize State ...");
this.unionOffsetStates = context.getOperatorStateStore().getUnionListState(new ListStateDescriptor<>(
OFFSETS_STATE_NAME, TypeInformation.of(new TypeHint<Tuple2<MessageQueue, Long>>() { })));
this.restored = context.isRestored();
if (restored) {
if (restoredOffsets == null) {
restoredOffsets = new ConcurrentHashMap<>();
}
for (Tuple2<MessageQueue, Long> mqOffsets : unionOffsetStates.get()) {
if (!restoredOffsets.containsKey(mqOffsets.f0) || restoredOffsets.get(mqOffsets.f0) < mqOffsets.f1) {
restoredOffsets.put(mqOffsets.f0, mqOffsets.f1);
}
}
LOG.info("Setting restore state in the consumer. Using the following offsets: {}", restoredOffsets);
} else {
LOG.info("No restore state for the consumer.");
}
}
@Override
public TypeInformation<OUT> getProducedType() {
return schema.getProducedType();
}
@Override
public void notifyCheckpointComplete(long checkpointId) throws Exception {
if (!runningChecker.isRunning()) {
LOG.debug("notifyCheckpointComplete() called on closed source; returning null.");
return;
}
final int posInMap = pendingOffsetsToCommit.indexOf(checkpointId);
if (posInMap == -1) {
LOG.warn("Received confirmation for unknown checkpoint id {}", checkpointId);
return;
}
Map<MessageQueue, Long> offsets = (Map<MessageQueue, Long>)pendingOffsetsToCommit.remove(posInMap);
for (int i = 0; i < posInMap; i++) {
pendingOffsetsToCommit.remove(0);
}
if (offsets == null || offsets.size() == 0) {
LOG.debug("Checkpoint state was empty.");
return;
}
for (Map.Entry<MessageQueue, Long> entry : offsets.entrySet()) {
consumer.updateConsumeOffset(entry.getKey(), entry.getValue());
}
}
}