Scan a Redis Cluster

The Redis SCAN command allows you to iterate over the key space. On a single Redis node you can SCAN all keys or just keys matching a pattern. It's a slow operation: O(N) where N is the number of keys in the database. Still, it's useful when you want to view every item in the database, or when there's no way to find your values without traversing everything. Because of the way Redis shards data, however, SCANning keys on a Redis Cluster is not as straightforward. It is possible, but it takes a little more work.

Scan a Redis node

Here’s an example SCAN of a single Redis node using the Jedis library for Java:

private void scanNode(Jedis node) {
    ScanParams scanParams = new ScanParams().count(1000);
    String cursor = ScanParams.SCAN_POINTER_START;
    do {
        ScanResult<String> scanResult = node.scan(cursor, scanParams);
        List<String> keys = scanResult.getResult();
        // SCAN may legitimately return an empty batch mid-iteration
        if (!keys.isEmpty()) {
            System.out.println("First key in batch: " + keys.get(0));
        }
        cursor = scanResult.getCursor();
    } while (!cursor.equals(ScanParams.SCAN_POINTER_START));
}

This example SCANs the single Redis node 1000 keys at a time. Note that COUNT is only a hint: Redis may return more or fewer keys in any given batch, including none at all, which is why the code guards against an empty batch.

Scan a Redis cluster

When working with clusters rather than single nodes, use a JedisCluster object rather than a Jedis object. For most operations you can swap directly between these two classes: both implement the JedisCommands interface, which includes scan. So let's try it:

private void scanBroken(JedisCluster cluster) {
    ScanParams scanParams = new ScanParams().count(1000);
    String cursor = ScanParams.SCAN_POINTER_START;
    do {
        ScanResult<String> scanResult = cluster.scan(cursor, scanParams);
        List<String> keys = scanResult.getResult();
        if (!keys.isEmpty()) {
            System.out.println("First key in batch: " + keys.get(0));
        }
        cursor = scanResult.getCursor();
    } while (!cursor.equals(ScanParams.SCAN_POINTER_START));
}

Running this immediately throws an IllegalArgumentException with the message:

Error: Cluster mode only supports SCAN command with MATCH pattern containing hash-tag ( curly-brackets enclosed string )

What’s the problem?

The problem is due to the way that Redis shards data. For each key you insert, Redis computes a hash slot (a CRC16 of the key, modulo 16384) and stores the key on the shard that owns that slot. When you later ask for a specific key, the same calculation tells Redis which shard to read from. This only works when you start from a specific key. A SCAN, by contrast, asks for all keys, and those are distributed across every shard, so no single node can answer the request.
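To make the routing concrete, here is a minimal sketch of the slot calculation (CRC16-XMODEM of the key, modulo 16384, hashing only the hash-tag substring when one is present). The class and method names here are mine, not part of any library:

```java
public class HashSlotSketch {

    // CRC16-XMODEM (polynomial 0x1021, initial value 0), the variant Redis Cluster uses
    static int crc16(byte[] bytes) {
        int crc = 0;
        for (byte b : bytes) {
            crc ^= (b & 0xFF) << 8;
            for (int i = 0; i < 8; i++) {
                crc = ((crc & 0x8000) != 0) ? ((crc << 1) ^ 0x1021) : (crc << 1);
                crc &= 0xFFFF;
            }
        }
        return crc;
    }

    // Slot = CRC16 of the key, or of its {hash tag} if a non-empty one is present
    static int slot(String key) {
        int open = key.indexOf('{');
        if (open >= 0) {
            int close = key.indexOf('}', open + 1);
            if (close > open + 1) {
                key = key.substring(open + 1, close);
            }
        }
        return crc16(key.getBytes(java.nio.charset.StandardCharsets.UTF_8)) % 16384;
    }

    public static void main(String[] args) {
        // 0x31C3 (12739) is the standard CRC16-XMODEM check value for "123456789"
        System.out.println(slot("123456789"));                      // 12739
        System.out.println(slot("{users}:1") == slot("{users}:2")); // true: same hash tag, same slot
    }
}
```

Jedis ships its own implementation of this calculation (JedisClusterCRC16), so in practice you would not write it yourself; the sketch is only to show why two keys with the same hash tag always land on the same shard.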

Scan by MATCH pattern

The error message suggests that cluster mode does support SCAN with a MATCH pattern, as long as the pattern contains a hash tag (a substring enclosed in curly braces). This requires a little design effort up front: if you explicitly control the part of the key that gets hashed, you can make sure that everything you want to scan ends up in the same shard.

There’s a good discussion of this on the Redis blog, but it’s beyond the scope of this article. Even so, it still does not let you scan everything.
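For illustration, suppose your keys were written with a shared hash tag such as {users}:<id> (a naming scheme I'm assuming here, not one from the original demo). Then the cluster-mode scan that failed earlier becomes legal, because the MATCH pattern contains a hash tag and JedisCluster can route the whole SCAN to the single node owning that slot. This sketch needs a live cluster with such keys to run:

```java
private void scanHashTag(JedisCluster cluster) {
    // The MATCH pattern contains a hash tag, so every matching key lives in one
    // slot and JedisCluster can route the SCAN to a single node
    ScanParams scanParams = new ScanParams().match("{users}:*").count(1000);
    String cursor = ScanParams.SCAN_POINTER_START;
    do {
        ScanResult<String> scanResult = cluster.scan(cursor, scanParams);
        for (String key : scanResult.getResult()) {
            System.out.println(key);
        }
        cursor = scanResult.getCursor();
    } while (!cursor.equals(ScanParams.SCAN_POINTER_START));
}
```

The trade-off is that this only ever sees keys sharing that one hash tag, which also concentrates them on a single shard.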

Correct way to scan a Redis cluster

A Redis cluster is just a collection of Redis nodes with the clever hash and shard mechanism applied on top. If we have a cluster, we can still perform operations on individual nodes. Our solution then is to:

  1. Find all nodes in the cluster
  2. For each node, SCAN and assemble a partial result
  3. Merge all partial results to a full result set

A simple implementation in Jedis looks like this:

private void scanAllNodes(JedisCluster cluster) {
    for (ConnectionPool node : cluster.getClusterNodes().values()) {
        try (Jedis j = new Jedis(node.getResource())) {
            scanNode(j); // Single node scan from earlier example
        }
    }
}

Performance improvement

The simple implementation runs the scans sequentially. Given that we’re running each scan on a different node, why not parallelize? Rather than run the scans sequentially, we can submit them as asynchronous tasks to Java’s ExecutorService and run them in parallel.

private void scanParallel(JedisCluster cluster) throws ExecutionException, InterruptedException {

    // Create an ExecutorService with enough threads to service all nodes concurrently
    ExecutorService scanExecutorService = Executors.newFixedThreadPool(cluster.getClusterNodes().size());

    // Submit scanNode(j) as an asynchronous job to the ExecutorService.
    // Each task opens and closes its own Jedis connection: wrapping submit()
    // in try-with-resources instead would close the connection before the
    // background task ever ran.
    List<Future<?>> results = new ArrayList<>();
    for (ConnectionPool node : cluster.getClusterNodes().values()) {
        Future<?> result = scanExecutorService.submit(() -> {
            try (Jedis j = new Jedis(node.getResource())) {
                scanNode(j); // Single node scan from earlier example
            }
        });
        results.add(result);
    }
    // At this point, the scans are running in parallel threads

    // Wait for all threads to finish
    for (Future<?> result : results) {
        result.get();
    }

    scanExecutorService.shutdown();
}

Return results

In the above examples, we assume that all work is done in the scanNode() method and nothing needs to be returned. However, it’s likely you’ll want to return some results and then merge them together using a MapReduce pattern.

Here’s a somewhat over-engineered solution where the process we want to perform is parameterized as a lambda function and the reduce step is a generic BinaryOperator. I’ve roughly followed Stream.reduce() as a design here. If you need this as a one-off, don’t bother with the complicated generics.

private <T> T scanAllNodes(JedisCluster cluster, Function<List<String>, T> keyFunction, T identity, BinaryOperator<T> accumulator) throws ExecutionException, InterruptedException {
    StopWatch timer = StopWatch.createStarted();
    T accumulatedResult = identity;
    List<Future<T>> results = new ArrayList<>();

    // Scan all nodes in parallel. Each task opens and closes its own Jedis
    // connection so the connection stays alive for the lifetime of the task.
    // scanExecutorService is a field, created as in the previous example.
    for (ConnectionPool node : cluster.getClusterNodes().values()) {
        Future<T> result = scanExecutorService.submit(() -> {
            try (Jedis j = new Jedis(node.getResource())) {
                return scan(j, keyFunction, identity, accumulator);
            }
        });
        results.add(result);
    }

    // Await and accumulate all scan results
    for (Future<T> result : results) {
        accumulatedResult = accumulator.apply(accumulatedResult, result.get());
    }

    timer.stop();
    System.out.println("Scanned " + NUM_KEYS + " keys in " + timer.formatTime());

    return accumulatedResult;
}

/**
 * Scan keys on a single Redis node and return accumulated result of a function on the keys
 * @param node Jedis
 * @param keyFunction function to be performed on keys
 * @param identity identity for accumulator
 * @param accumulator combine two keyFunction results into an accumulated result
 * @return results accumulated over complete scan of redis keys
 * @param <T> return type
 */
private <T> T scan(Jedis node, Function<List<String>, T> keyFunction, T identity, BinaryOperator<T> accumulator) {
    T accumulatedResult = identity;
    ScanParams scanParams = new ScanParams().count(SCAN_BATCH);
    String cursor = ScanParams.SCAN_POINTER_START;
    do {
        ScanResult<String> scanResult = node.scan(cursor, scanParams);
        List<String> keys = scanResult.getResult();
        T thisResult = keyFunction.apply(keys);
        accumulatedResult = accumulator.apply(accumulatedResult, thisResult);
        cursor = scanResult.getCursor();
    } while (!cursor.equals(ScanParams.SCAN_POINTER_START));
    return accumulatedResult;
}
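The reduce shape is easy to exercise without a cluster: apply the same keyFunction/identity/accumulator triple to plain batches of strings. A self-contained sketch (the class and helper names are mine):

```java
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Function;

public class ReduceShapeDemo {

    // The same triple you would pass to scanAllNodes, applied to in-memory batches
    static <T> T reduceBatches(List<List<String>> batches, Function<List<String>, T> keyFunction,
                               T identity, BinaryOperator<T> accumulator) {
        T acc = identity;
        for (List<String> batch : batches) {
            acc = accumulator.apply(acc, keyFunction.apply(batch));
        }
        return acc;
    }

    public static void main(String[] args) {
        // Three "SCAN batches", one of them empty, as a real scan might return
        List<List<String>> batches = List.of(List.of("a", "b"), List.of("c"), List.of());

        // Count keys per batch, then sum the counts
        long total = reduceBatches(batches, keys -> (long) keys.size(), 0L, Long::sum);
        System.out.println(total); // 3
    }
}
```

Against a real cluster, the same triple plugs straight into the method above, for example scanAllNodes(cluster, keys -> (long) keys.size(), 0L, Long::sum) to count every key in the cluster.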

Take a look at the full demo on GitHub. The demo also includes a Docker Redis cluster environment that you can run the example code against.
