Skip to content

Commit

Permalink
feature: replace Slices with Pages (for proper pagination based on Re…
Browse files Browse the repository at this point in the history
…dis Cursors)
  • Loading branch information
bsbodden committed Aug 23, 2024
1 parent 2c41414 commit b186349
Show file tree
Hide file tree
Showing 15 changed files with 88 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -403,25 +403,10 @@ public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
SearchStream<S> stream = entityStream.of(example.getProbeType());
var offset = pageable.getPageNumber() * pageable.getPageSize();
var limit = pageable.getPageSize();
Slice<S> slice = stream.filter(example).loadAll().limit(limit, Math.toIntExact(offset))
Page<S> page = stream.filter(example).loadAll().limit(limit, Math.toIntExact(offset))
.toList(pageable, stream.getEntityClass());

if (indexer.indexDefinitionExistsFor(metadata.getJavaType())) {
String searchIndex = indexer.getIndexName(metadata.getJavaType());

SearchOperations<String> searchOps = modulesOperations.opsForSearch(searchIndex);
Query query = new Query(stream.backingQuery());
query.setNoContent();

for (Order order : pageable.getSort()) {
query.setSortBy(order.getProperty(), order.isAscending());
}

SearchResult searchResult = searchOps.search(query);
return pageFromSlice(slice, searchResult.getTotalResults(), pageable.getPageSize());
} else {
return pageFromSlice(slice);
}
return page;
}

/* (non-Javadoc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import static com.redis.om.spring.RedisOMProperties.MAX_SEARCH_RESULTS;
import static com.redis.om.spring.util.ObjectUtils.*;

public class SimpleRedisEnhancedRepository<T, ID> extends SimpleKeyValueRepository<T, ID>
Expand Down Expand Up @@ -405,25 +404,10 @@ public <S extends T> Page<S> findAll(Example<S> example, Pageable pageable) {
SearchStream<S> stream = entityStream.of(example.getProbeType());
var offset = pageable.getPageNumber() * pageable.getPageSize();
var limit = pageable.getPageSize();
Slice<S> slice = stream.filter(example).loadAll().limit(limit, Math.toIntExact(offset))
Page<S> page = stream.filter(example).loadAll().limit(limit, Math.toIntExact(offset))
.toList(pageable, stream.getEntityClass());

if (indexer.indexDefinitionExistsFor(metadata.getJavaType())) {
String searchIndex = indexer.getIndexName(metadata.getJavaType());

SearchOperations<String> searchOps = modulesOperations.opsForSearch(searchIndex);
Query query = new Query(stream.backingQuery());
query.setNoContent();

for (Order order : pageable.getSort()) {
query.setSortBy(order.getProperty(), order.isAscending());
}

SearchResult searchResult = searchOps.search(query);
return pageFromSlice(slice, searchResult.getTotalResults(), pageable.getPageSize());
} else {
return pageFromSlice(slice);
}
return page;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,54 @@

import com.google.gson.Gson;
import com.redis.om.spring.convert.MappingRedisOMConverter;
import com.redis.om.spring.ops.search.SearchOperations;
import com.redis.om.spring.util.ObjectUtils;
import org.springframework.data.domain.*;
import org.springframework.data.domain.Sort.Order;
import org.springframework.util.Assert;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.SearchResult;
import redis.clients.jedis.search.aggr.AggregationResult;

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;

public class AggregationPage<E> implements Slice<E>, Serializable {
public class AggregationPage<E> implements Page<E>, Serializable {
private final transient Pageable pageable;
private final transient Gson gson;
private final Class<E> entityClass;
private final boolean isDocument;
private final transient MappingRedisOMConverter mappingConverter;
private final SearchOperations<String> search;
private List<E> content;
private transient AggregationStream<E> aggregationStream;
private long cursorId = -1;
private AggregationResult aggregationResult;
private Long totalElementCount;

public AggregationPage(AggregationStream<E> aggregationStream, Pageable pageable, Class<E> entityClass, Gson gson,
MappingRedisOMConverter mappingConverter, boolean isDocument) {
MappingRedisOMConverter mappingConverter, boolean isDocument, SearchOperations<String> search) {
this.aggregationStream = aggregationStream;
this.pageable = pageable;
this.entityClass = entityClass;
this.gson = gson;
this.isDocument = isDocument;
this.mappingConverter = mappingConverter;
this.search = search;
}

public AggregationPage(AggregationResult aggregationResult, Pageable pageable, Class<E> entityClass, Gson gson,
MappingRedisOMConverter mappingConverter, boolean isDocument) {
MappingRedisOMConverter mappingConverter, boolean isDocument, SearchOperations<String> search) {
this.aggregationResult = aggregationResult;
this.pageable = pageable;
this.entityClass = entityClass;
this.gson = gson;
this.cursorId = aggregationResult.getCursorId();
this.isDocument = isDocument;
this.mappingConverter = mappingConverter;
this.search = search;
}

@Override
Expand Down Expand Up @@ -86,7 +94,8 @@ public boolean isLast() {

@Override
public boolean hasNext() {
return cursorId == -1 || resolveCursorId() != 0;
// return cursorId == -1 || resolveCursorId() != 0;
return aggregationStream != null ? getNumber() + 1 < getTotalPages() : cursorId == -1 || resolveCursorId() != 0;
}

@Override
Expand All @@ -96,8 +105,34 @@ public Pageable nextPageable() {
}

@Override
public <U> Slice<U> map(Function<? super E, ? extends U> converter) {
return new SliceImpl<>(getConvertedContent(converter), pageable, hasNext());
public int getTotalPages() {
return (getTotalElements() == 0 || pageable.getPageSize() == 0) ?
0 :
(int) Math.ceil((double) getTotalElements() / (double) pageable.getPageSize());
}

@Override
public long getTotalElements() {
if (totalElementCount == null) {
if (aggregationStream != null) {
String baseQuery = aggregationStream.backingQuery();
Query countQuery = (baseQuery.isBlank()) ? new Query() : new Query(baseQuery);
countQuery.setNoContent();
for (Order order : pageable.getSort()) {
countQuery.setSortBy(order.getProperty(), order.isAscending());
}
SearchResult searchResult = search.search(countQuery);
totalElementCount = searchResult.getTotalResults();
} else {
totalElementCount = aggregationResult.getTotalResults(); // not quite sure about this shit
}
}
return totalElementCount;
}

@Override
public <U> Page<U> map(Function<? super E, ? extends U> converter) {
return new PageImpl<>(getConvertedContent(converter));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import com.redis.om.spring.annotations.ReducerFunction;
import com.redis.om.spring.metamodel.MetamodelField;
import com.redis.om.spring.search.stream.aggregations.filters.AggregationFilter;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort.Order;
import redis.clients.jedis.search.aggr.AggregationResult;

Expand Down Expand Up @@ -50,10 +50,12 @@ public interface AggregationStream<T> {

<R extends T> List<R> toList(Class<?>... contentTypes);

String backingQuery();

// Cursor API
AggregationStream<T> cursor(int i, Duration duration);

<R extends T> Slice<R> toList(Pageable pageRequest, Class<?>... contentTypes);
<R extends T> Page<R> toList(Pageable pageRequest, Class<?>... contentTypes);

<R extends T> Slice<R> toList(Pageable pageRequest, Duration duration, Class<?>... contentTypes);
<R extends T> Page<R> toList(Pageable pageRequest, Duration duration, Class<?>... contentTypes);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import com.redis.om.spring.search.stream.aggregations.filters.AggregationFilter;
import com.redis.om.spring.tuple.Tuples;
import com.redis.om.spring.util.ObjectUtils;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort.Direction;
import org.springframework.data.domain.Sort.Order;
import org.springframework.data.redis.core.convert.ReferenceResolverImpl;
Expand All @@ -34,6 +34,7 @@ public class AggregationStreamImpl<E, T> implements AggregationStream<T> {
private Group currentGroup;
private ReducerFieldPair currentReducer;
private boolean limitSet = false;
private final String query;

@SafeVarargs
public AggregationStreamImpl(String searchIndex, RedisModulesOperations<String> modulesOperations, Gson gson,
Expand All @@ -42,6 +43,7 @@ public AggregationStreamImpl(String searchIndex, RedisModulesOperations<String>
search = modulesOperations.opsForSearch(searchIndex);
aggregation = new AggregationBuilder(query);
isDocument = entityClass.isAnnotationPresent(Document.class);
this.query = query;
this.gson = gson;
this.mappingConverter = new MappingRedisOMConverter(null, new ReferenceResolverImpl(modulesOperations.template()));
createAggregationGroup(fields);
Expand Down Expand Up @@ -376,6 +378,11 @@ public <R extends T> List<R> toList(Class<?>... contentTypes) {
return (List<R>) asList;
}

@Override
public String backingQuery() {
return query;
}

@Override
public AggregationStream<T> cursor(int count, Duration timeout) {
applyCurrentGroupBy();
Expand All @@ -387,18 +394,18 @@ public AggregationStream<T> cursor(int count, Duration timeout) {

@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public <R extends T> Slice<R> toList(Pageable pageRequest, Class<?>... contentTypes) {
public <R extends T> Page<R> toList(Pageable pageRequest, Class<?>... contentTypes) {
applyCurrentGroupBy();
aggregation.cursor(pageRequest.getPageSize(), 300000);
return new AggregationPage(this, pageRequest, entityClass, gson, mappingConverter, isDocument);
return new AggregationPage(this, pageRequest, entityClass, gson, mappingConverter, isDocument, this.search);
}

@Override
@SuppressWarnings({ "unchecked", "rawtypes" })
public <R extends T> Slice<R> toList(Pageable pageRequest, Duration timeout, Class<?>... contentTypes) {
public <R extends T> Page<R> toList(Pageable pageRequest, Duration timeout, Class<?>... contentTypes) {
applyCurrentGroupBy();
aggregation.cursor(pageRequest.getPageSize(), timeout.toMillis());
return new AggregationPage(this, pageRequest, entityClass, gson, mappingConverter, isDocument);
return new AggregationPage(this, pageRequest, entityClass, gson, mappingConverter, isDocument, this.search);
}

private void applyCurrentGroupBy() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import org.apache.commons.logging.LogFactory;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.Query;
Expand Down Expand Up @@ -454,7 +454,7 @@ public SearchOperations<String> getSearchOperations() {
}

@Override
public Slice<T> getSlice(Pageable pageable) {
public Page<T> getPage(Pageable pageable) {
throw new UnsupportedOperationException("getPage is not supported on a ReturnFieldSearchStream");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import com.redis.om.spring.search.stream.predicates.SearchFieldPredicate;
import com.redis.om.spring.tuple.Pair;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import redis.clients.jedis.search.aggr.SortedField.SortOrder;

Expand Down Expand Up @@ -118,7 +118,7 @@ public interface SearchStream<E> extends BaseStream<E, SearchStream<E>> {

SearchOperations<String> getSearchOperations();

Slice<E> getSlice(Pageable pageable);
Page<E> getPage(Pageable pageable);

<R> SearchStream<E> project(Function<? super E, ? extends R> field);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@

import static com.redis.om.spring.metamodel.MetamodelUtils.getMetamodelForIdField;
import static com.redis.om.spring.util.ObjectUtils.floatArrayToByteArray;
import static com.redis.om.spring.util.ObjectUtils.pageFromSlice;
import static java.util.stream.Collectors.toCollection;

public class SearchStreamImpl<E> implements SearchStream<E> {
Expand Down Expand Up @@ -722,12 +721,12 @@ public SearchOperations<String> getSearchOperations() {
}

@Override
public Slice<E> getSlice(Pageable pageable) {
public Page<E> getPage(Pageable pageable) {
if (pageable.getClass().isAssignableFrom(AggregationPageable.class)) {
resolvedStream = Stream.empty();
AggregationPageable ap = (AggregationPageable) pageable;
AggregationResult ar = search.cursorRead(ap.getCursorId(), pageable.getPageSize());
return new AggregationPage<>(ar, pageable, entityClass, getGson(), mappingConverter, isDocument);
return new AggregationPage<>(ar, pageable, entityClass, getGson(), mappingConverter, isDocument, this.search);
} else {
if (!isStreamResolved()) {
this.sorted(pageable.getSort()).limit(pageable.getPageSize()).skip(Math.toIntExact(pageable.getOffset()));
Expand All @@ -736,9 +735,9 @@ public Slice<E> getSlice(Pageable pageable) {
countQuery.limit(Math.toIntExact(pageable.getOffset() + pageable.getPageSize()), pageable.getPageSize());
SearchResult searchResult = search.search(countQuery);

return new SliceImpl<>(this.resolveStream().toList(), pageable, !searchResult.getDocuments().isEmpty());
return new PageImpl<>(this.resolveStream().toList(), pageable, searchResult.getTotalResults());
} else {
return new SliceImpl<E>(List.of());
return new PageImpl<E>(List.of());
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import com.redis.om.spring.search.stream.predicates.SearchFieldPredicate;
import com.redis.om.spring.tuple.Pair;
import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Slice;
import org.springframework.data.domain.Sort;
import redis.clients.jedis.search.aggr.SortedField.SortOrder;

Expand Down Expand Up @@ -313,7 +313,7 @@ public SearchOperations<String> getSearchOperations() {
}

@Override
public Slice<E> getSlice(Pageable pageable) {
public Page<E> getPage(Pageable pageable) {
throw new UnsupportedOperationException("getPage is not supported on a WrappedSearchStream");
}

Expand Down
Loading

0 comments on commit b186349

Please sign in to comment.