migrate token bucket redis record format from json to hash: phase 1

This commit is contained in:
Sergey Skrobotov
2023-03-15 14:32:14 -07:00
parent ebf8aa7b15
commit 483e444174
6 changed files with 277 additions and 43 deletions

View File

@@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.whispersystems.textsecuregcm.util.redis.RedisLuaScriptSandbox.tail;
import java.util.List;
import java.util.Map;
/**
* This class is to be extended with implementations of Redis commands as needed.
@@ -28,23 +29,79 @@ public class BaseRedisCommandsHandler implements RedisCommandsHandler {
yield get(args.get(0).toString());
}
case "DEL" -> {
assertTrue(args.size() > 1);
yield del(args.get(0).toString());
assertTrue(args.size() >= 1);
yield del(args.stream().map(Object::toString).toList());
}
case "HSET" -> {
assertTrue(args.size() >= 3);
yield hset(args.get(0).toString(), args.get(1).toString(), args.get(2).toString(), tail(args, 3));
}
case "HGET" -> {
assertEquals(2, args.size());
yield hget(args.get(0).toString(), args.get(1).toString());
}
case "PEXPIRE" -> {
assertEquals(2, args.size());
yield pexpire(args.get(0).toString(), Double.valueOf(args.get(1).toString()).longValue(), tail(args, 2));
}
case "TYPE" -> {
assertEquals(1, args.size());
yield type(args.get(0).toString());
}
case "RPUSH" -> {
assertTrue(args.size() > 1);
yield push(false, args.get(0).toString(), tail(args, 1));
}
case "LPUSH" -> {
assertTrue(args.size() > 1);
yield push(true, args.get(0).toString(), tail(args, 1));
}
case "RPOP" -> {
assertEquals(2, args.size());
yield pop(false, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue());
}
case "LPOP" -> {
assertEquals(2, args.size());
yield pop(true, args.get(0).toString(), Double.valueOf(args.get(1).toString()).intValue());
}
default -> other(command, args);
};
}
public Object[] pop(final boolean left, final String key, final int count) {
return new Object[count];
}
public Object push(final boolean left, final String key, final List<Object> values) {
return 0;
}
public Object type(final String key) {
return Map.of("ok", "none");
}
public Object pexpire(final String key, final long ttlMillis, final List<Object> args) {
return 0;
}
public Object hset(final String key, final String field, final String value, final List<Object> other) {
return "OK";
}
public Object hget(final String key, final String field) {
return null;
}
public Object set(final String key, final String value, final List<Object> tail) {
return "OK";
}
public String get(final String key) {
return null;
}
public int del(final String key) {
public int del(final List<String> keys) {
return 0;
}

View File

@@ -28,7 +28,12 @@ public class RedisLuaScriptSandbox {
function redis_call(...)
-- variable name needs to match the one used in the `L.setGlobal()` call
-- method name needs to match method name of the Java class
return proxy:redisCall(arg)
local result = proxy:redisCall(arg)
if type(result) == "userdata" then
return java.luaify(result)
else
return result
end
end
function json_encode(obj)

View File

@@ -6,13 +6,15 @@
package org.whispersystems.textsecuregcm.util.redis;
import java.time.Clock;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
public record Entry(String value, long expirationEpochMillis) {
public record Entry(Object value, long expirationEpochMillis) {
}
private final Map<String, Entry> cache = new ConcurrentHashMap<>();
@@ -32,20 +34,106 @@ public class SimpleCacheCommandsHandler extends BaseRedisCommandsHandler {
@Override
public String get(final String key) {
return getIfNotExpired(key, String.class);
}
@Override
public int del(final List<String> key) {
return key.stream()
.mapToInt(k -> cache.remove(k) != null ? 1 : 0)
.sum();
}
@SuppressWarnings("unchecked")
@Override
public Object hset(final String key, final String field, final String value, final List<Object> other) {
Map<Object, Object> map = getIfNotExpired(key, Map.class);
if (map == null) {
map = new ConcurrentHashMap<>();
cache.put(key, new Entry(map, Long.MAX_VALUE));
}
map.put(field, value);
return "OK";
}
@Override
public Object hget(final String key, final String field) {
final Map<?, ?> map = getIfNotExpired(key, Map.class);
return map == null ? null : map.get(field);
}
@SuppressWarnings("unchecked")
@Override
public Object push(final boolean left, final String key, final List<Object> values) {
LinkedList<Object> list = getIfNotExpired(key, LinkedList.class);
if (list == null) {
list = new LinkedList<>();
cache.put(key, new Entry(list, Long.MAX_VALUE));
}
for (Object v: values) {
if (left) {
list.addFirst(v.toString());
} else {
list.addLast(v.toString());
}
}
return list.size();
}
@SuppressWarnings("unchecked")
@Override
public Object[] pop(final boolean left, final String key, final int count) {
final Object[] result = new String[count];
final LinkedList<Object> list = getIfNotExpired(key, LinkedList.class);
if (list == null) {
return result;
}
for (int i = 0; i < Math.min(count, list.size()); i++) {
result[i] = left ? list.removeFirst() : list.removeLast();
}
return result;
}
@Override
public Object pexpire(final String key, final long ttlMillis, final List<Object> args) {
final Entry e = cache.get(key);
if (e == null) {
return 0;
}
final Entry updated = new Entry(e.value(), clock.millis() + ttlMillis);
cache.put(key, updated);
return 1;
}
@Override
public Object type(final String key) {
final Object o = getIfNotExpired(key, Object.class);
final String type;
if (o == null) {
type = "none";
} else if (o.getClass() == String.class) {
type = "string";
} else if (Map.class.isAssignableFrom(o.getClass())) {
type = "hash";
} else if (List.class.isAssignableFrom(o.getClass())) {
type = "list";
} else {
throw new IllegalArgumentException("Unsupported value type: " + o.getClass());
}
return Map.of("ok", type);
}
@Nullable
protected <T> T getIfNotExpired(final String key, final Class<T> expectedType) {
final Entry entry = cache.get(key);
if (entry == null) {
return null;
}
if (entry.expirationEpochMillis() < clock.millis()) {
del(key);
del(List.of(key));
return null;
}
return entry.value();
}
@Override
public int del(final String key) {
return cache.remove(key) != null ? 1 : 0;
return expectedType.cast(entry.value());
}
protected long resolveExpirationEpochMillis(final List<Object> args) {