Lazy-load scripts; fall back to eval if evalsha returns NOSCRIPT

This commit is contained in:
Jon Chambers
2021-09-28 12:21:56 -04:00
committed by Jon Chambers
parent f37c76dab1
commit aa4bd92fee
2 changed files with 108 additions and 139 deletions

View File

@@ -5,146 +5,121 @@
package org.whispersystems.textsecuregcm.redis;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import io.lettuce.core.RedisNoScriptException;
import io.lettuce.core.ScriptOutputType;
import io.lettuce.core.api.sync.RedisCommands;
import io.lettuce.core.cluster.SlotHash;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import io.lettuce.core.cluster.models.partitions.RedisClusterNode;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.junit.Test;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.whispersystems.textsecuregcm.tests.util.RedisClusterHelper;
public class ClusterLuaScriptTest extends AbstractRedisClusterTest {
public class ClusterLuaScriptTest {
@RegisterExtension
static final RedisClusterExtension REDIS_CLUSTER_EXTENSION = RedisClusterExtension.builder().build();
@Test
public void testExecuteMovedKey() {
final String key = "key";
final String value = "value";
final FaultTolerantRedisCluster redisCluster = getRedisCluster();
final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])",
ScriptOutputType.VALUE);
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
final int slot = SlotHash.getSlot(key);
final int sourcePort = redisCluster.withCluster(
connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM))
.node(0).getUri().getPort());
final RedisCommands<String, String> sourceCommands = redisCluster.withCluster(
connection -> connection.sync().nodes(node -> node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM))
.commands(0));
final RedisCommands<String, String> destinationCommands = redisCluster.withCluster(connection -> connection.sync()
.nodes(node -> !node.hasSlot(slot) && node.is(RedisClusterNode.NodeFlag.UPSTREAM)).commands(0));
destinationCommands.clusterSetSlotImporting(slot, sourceCommands.clusterMyId());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
sourceCommands.clusterSetSlotMigrating(slot, destinationCommands.clusterMyId());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
for (final String migrateKey : sourceCommands.clusterGetKeysInSlot(slot, Integer.MAX_VALUE)) {
destinationCommands.migrate("127.0.0.1", sourcePort, migrateKey, 0, 1000);
}
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
destinationCommands.clusterSetSlotNode(slot, destinationCommands.clusterMyId());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
}
@Test
public void testExecute() {
void testExecute() {
final RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final String sha = "abc123";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<String> keys = List.of("key");
final List<String> values = List.of("value");
when(commands.scriptLoad(script)).thenReturn(sha);
when(commands.evalsha(any(), any(), any(), any())).thenReturn("OK");
new ClusterLuaScript(mockCluster, script, scriptOutputType).execute(keys, values);
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.execute(keys, values);
verify(commands).scriptLoad(script);
verify(commands).evalsha(sha, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
verify(commands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
verify(commands, never()).eval(anyString(), any(), any(), any());
}
@Test
public void testExecuteNoScriptException() {
final String key = "key";
final String value = "value";
final FaultTolerantRedisCluster redisCluster = getRedisCluster();
final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])",
ScriptOutputType.VALUE);
// Remove the scripts created by the CLusterLuaScript constructor
redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush());
assertEquals("OK", script.execute(List.of(key), List.of(value)));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
}
@Test
public void testExecuteBinary() {
final RedisAdvancedClusterCommands<String, String> stringCommands = mock(RedisAdvancedClusterCommands.class);
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper
.buildMockRedisCluster(stringCommands, binaryCommands);
void testExecuteScriptNotLoaded() {
final RedisAdvancedClusterCommands<String, String> commands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster = RedisClusterHelper.buildMockRedisCluster(commands);
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<String> keys = List.of("key");
final List<String> values = List.of("value");
when(commands.evalsha(any(), any(), any(), any())).thenThrow(new RedisNoScriptException("OH NO"));
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.execute(keys, values);
verify(commands).eval(script, scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
verify(commands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new String[0]), values.toArray(new String[0]));
}
@Test
void testExecuteBinaryScriptNotLoaded() {
final RedisAdvancedClusterCommands<String, String> stringCommands = mock(RedisAdvancedClusterCommands.class);
final RedisAdvancedClusterCommands<byte[], byte[]> binaryCommands = mock(RedisAdvancedClusterCommands.class);
final FaultTolerantRedisCluster mockCluster =
RedisClusterHelper.buildMockRedisCluster(stringCommands, binaryCommands);
final String script = "return redis.call(\"SET\", KEYS[1], ARGV[1])";
final String sha = "abc123";
final ScriptOutputType scriptOutputType = ScriptOutputType.VALUE;
final List<byte[]> keys = List.of("key".getBytes(StandardCharsets.UTF_8));
final List<byte[]> values = List.of("value".getBytes(StandardCharsets.UTF_8));
when(stringCommands.scriptLoad(script)).thenReturn(sha);
when(binaryCommands.evalsha(any(), any(), any(), any())).thenReturn("OK".getBytes(StandardCharsets.UTF_8));
when(binaryCommands.evalsha(any(), any(), any(), any())).thenThrow(new RedisNoScriptException("OH NO"));
new ClusterLuaScript(mockCluster, script, scriptOutputType).executeBinary(keys, values);
final ClusterLuaScript luaScript = new ClusterLuaScript(mockCluster, script, scriptOutputType);
luaScript.executeBinary(keys, values);
verify(stringCommands).scriptLoad(script);
verify(binaryCommands).evalsha(sha, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
verify(binaryCommands).eval(script, scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
verify(binaryCommands).evalsha(luaScript.getSha(), scriptOutputType, keys.toArray(new byte[0][]), values.toArray(new byte[0][]));
}
@Test
public void testExecuteBinaryNoScriptException() {
final String key = "key";
final String value = "value";
public void testExecuteRealCluster() {
final ClusterLuaScript script = new ClusterLuaScript(REDIS_CLUSTER_EXTENSION.getRedisCluster(),
"return 2;",
ScriptOutputType.INTEGER);
final FaultTolerantRedisCluster redisCluster = getRedisCluster();
for (int i = 0; i < 7; i++) {
assertEquals(2L, script.execute(Collections.emptyList(), Collections.emptyList()));
}
final ClusterLuaScript script = new ClusterLuaScript(redisCluster, "return redis.call(\"SET\", KEYS[1], ARGV[1])",
ScriptOutputType.VALUE);
final int evalCount = REDIS_CLUSTER_EXTENSION.getRedisCluster().withCluster(connection -> {
final String commandStats = connection.sync().info("commandstats");
// Remove the scripts created by the CLusterLuaScript constructor
redisCluster.useCluster(connection -> connection.sync().upstream().commands().scriptFlush());
// We're looking for (and parsing) a line in the command stats that looks like:
//
// ```
// cmdstat_eval:calls=1,usec=44,usec_per_call=44.00
// ```
return Arrays.stream(commandStats.split("\\n"))
.filter(line -> line.startsWith("cmdstat_eval:"))
.map(String::trim)
.map(evalLine -> Arrays.stream(evalLine.substring(evalLine.indexOf(':') + 1).split(","))
.filter(pair -> pair.startsWith("calls="))
.map(callsPair -> Integer.parseInt(callsPair.substring(callsPair.indexOf('=') + 1)))
.findFirst()
.orElse(0))
.findFirst()
.orElse(0);
});
assertArrayEquals("OK".getBytes(StandardCharsets.UTF_8), (byte[]) script
.executeBinary(List.of(key.getBytes(StandardCharsets.UTF_8)), List.of(value.getBytes(StandardCharsets.UTF_8))));
assertEquals(value, redisCluster.withCluster(connection -> connection.sync().get(key)));
assertEquals(1, evalCount);
}
}