Skip to content

Commit

Permalink
FixGroupByPaging (Azure#13688)
Browse files Browse the repository at this point in the history
* FixGroupByPagingScenario

Co-authored-by: Annie Liang <xinlian@microsoft.com>
  • Loading branch information
xinlian12 and Annie Liang authored Aug 2, 2020
1 parent 49eabe1 commit 9b12aa4
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.azure.cosmos.models.ModelBridgeInternal;
import com.fasterxml.jackson.databind.node.ObjectNode;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -67,32 +68,51 @@ public static <T extends Resource> Flux<IDocumentQueryExecutionComponent<T>> cre
/**
 * Drains the wrapped component, aggregates the returned documents through
 * {@code aggregateGroupings}, and emits group-by results as feed responses.
 *
 * @param maxPageSize page size forwarded to the wrapped component and used when draining the
 *                    grouping table
 * @return a {@link Flux} of feed responses built from the grouping table
 */
@Override
public Flux<FeedResponse<T>> drainAsync(int maxPageSize) {
return this.component.drainAsync(maxPageSize)
// NOTE(review): this listing carries two copies of the collectList()/map stage. The first copy
// (ending in .flux()) never updates requestCharge and drains the grouping table only once.
.collectList()
.map(superList -> {
double requestCharge = 0;
HashMap<String, String> headers = new HashMap<>();
List<Document> documentList = new ArrayList<>();
/* Do groupby stuff here */
// Stage 1:
// Drain the groupings fully from all continuation and all partitions
for (FeedResponse<T> page : superList) {
List<Document> results = (List<Document>) page.getResults();
documentList.addAll(results);
}

this.aggregateGroupings(documentList);

// Stage 2:
// Emit the results from the grouping table page by page

List<Document> groupByResults = this.groupingTable.drain(maxPageSize);

headers.put(HttpConstants.HttpHeaders.REQUEST_CHARGE, Double.toString(requestCharge));
FeedResponse<Document> frp =
BridgeInternal.createFeedResponse(groupByResults, headers);

return (FeedResponse<T>) frp;
}).flux();
// Second copy: accumulates the request charge of every upstream page and adds an expand stage
// that keeps draining the grouping table.
.collectList()
.map(superList -> {
double requestCharge = 0;
HashMap<String, String> headers = new HashMap<>();
List<Document> documentList = new ArrayList<>();
/* Do groupBy stuff here */
// Stage 1:
// Drain the groupings fully from all continuation and all partitions
for (FeedResponse<T> page : superList) {
List<Document> results = (List<Document>) page.getResults();
documentList.addAll(results);
// Sum the charge of every upstream page so the first emitted page reports the total.
requestCharge += page.getRequestCharge();
}

this.aggregateGroupings(documentList);

// Stage 2:
// Emit the results from the grouping table page by page
// NOTE(review): createFeedResponseFromGroupingTable returns null when the table yields no
// results; confirm that returning null from this map stage is acceptable to the operator.
return createFeedResponseFromGroupingTable(maxPageSize, requestCharge);
}).expand(tFeedResponse -> {
// For groupBy query, we have already drained everything for the first page request
// so for following requests, we will just need to drain page by page from the grouping table
// Follow-up pages are created with a request charge of 0.
FeedResponse<T> response = createFeedResponseFromGroupingTable(maxPageSize, 0);
if (response == null) {
// A null page ends the expansion.
return Mono.empty();
}
return Mono.just(response);
});
}

/**
 * Builds one feed response from the next batch returned by {@code groupingTable.drain(pageSize)}.
 *
 * @param pageSize value passed to the grouping table's {@code drain} call
 * @param requestCharge value written to the request-charge header of the response
 * @return the feed response, or {@code null} when there is no grouping table or the drained
 *         batch is empty
 */
@SuppressWarnings("unchecked") // safe to upcast
private FeedResponse<T> createFeedResponseFromGroupingTable(int pageSize, double requestCharge) {
    // No grouping table: there is nothing to emit.
    if (this.groupingTable == null) {
        return null;
    }

    List<Document> drainedPage = this.groupingTable.drain(pageSize);
    // An empty batch is reported as null; callers use that to detect the end of the results.
    if (drainedPage.isEmpty()) {
        return null;
    }

    HashMap<String, String> responseHeaders = new HashMap<>();
    String chargeText = Double.toString(requestCharge);
    responseHeaders.put(HttpConstants.HttpHeaders.REQUEST_CHARGE, chargeText);

    FeedResponse<Document> documentResponse =
        BridgeInternal.createFeedResponse(drainedPage, responseHeaders);
    return (FeedResponse<T>) documentResponse;
}

private void aggregateGroupings(List<Document> superList) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.commons.lang3.tuple.Triple;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Factory;
import org.testng.annotations.Test;

Expand All @@ -27,11 +29,13 @@
import java.util.Map;
import java.util.Random;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

public class GroupByQueryTests extends TestSuiteBase {
private final static int INSERT_DOCUMENTS_CNT = 40;
List<Person> personList;
private CosmosAsyncContainer createdCollection;
private ArrayList<InternalObjectNode> docs = new ArrayList<>();
Expand Down Expand Up @@ -71,35 +75,61 @@ private static int getRandomAge(Random rand) {
return rand.nextInt(100);
}

@Test(groups = {"simple"}, timeOut = TIMEOUT)
public void queryDocuments() {
/**
 * Supplies the group-by configurations consumed by {@code queryDocuments}.
 *
 * <p>Each entry is a {@code Triple} holding:
 * <ul>
 *   <li>left: the property name used in the SELECT and GROUP BY clauses</li>
 *   <li>middle: the function that reads that property from a {@code Person}, used to compute
 *       the expected grouping</li>
 *   <li>right: the max item count set on the query request options</li>
 * </ul>
 *
 * @return one {@code Triple} per test invocation
 */
@DataProvider
public static Object[] groupByConfigProvider() {
    // Typed as Function<Person, Object> so the entries match the parameter type of queryDocuments.
    Function<Person, Object> cityOf = Person::getCity;
    Function<Person, Object> guidOf = Person::getGuid;

    return new Object[]{
        Triple.of("city", cityOf, 35),
        // Half of the inserted document count: this is to make sure we are testing paging scenario
        Triple.of("guid", guidOf, INSERT_DOCUMENTS_CNT / 2)
    };
}

@Test(groups = {"simple"}, dataProvider = "groupByConfigProvider", timeOut = TIMEOUT)
public void queryDocuments(Triple<String, Function<Person, Object>, Integer> groupByConfig) {
boolean qmEnabled = true;

String query = "SELECT sum(c.age) as sum_age, c.city FROM c group by c.city";
String query =
String.format("SELECT sum(c.age) as sum_age, c.%s FROM c group by c.%s", groupByConfig.getLeft(), groupByConfig.getLeft());
CosmosQueryRequestOptions options = new CosmosQueryRequestOptions();
ModelBridgeInternal.setQueryRequestOptionsMaxItemCount(options, 35);
ModelBridgeInternal.setQueryRequestOptionsMaxItemCount(options, groupByConfig.getRight());
options.setQueryMetricsEnabled(qmEnabled);
options.setMaxDegreeOfParallelism(2);
CosmosPagedFlux<JsonNode> queryObservable = createdCollection.queryItems(query,
options,
JsonNode.class);
Map<City, Integer> resultMap = personList.stream()
.collect(Collectors.groupingBy(Person::getCity,
Map<Object, Integer> resultMap = personList.stream()
.collect(Collectors.groupingBy(groupByConfig.getMiddle(),
Collectors.summingInt(Person::getAge)));

List<Document> expectedDocumentsList = new ArrayList<>();
resultMap.forEach((city, sum) ->
resultMap.forEach((groupByObj, sum) ->
{
Document d = new Document();
d.set("sum_age", sum);
d.set("city", city);
d.set(groupByConfig.getLeft(), groupByObj);
expectedDocumentsList.add(d);
});


List<FeedResponse<JsonNode>> queryResultPages = queryObservable.byPage().collectList().block();

List<JsonNode> queryResults = new ArrayList<>();

queryResultPages
.forEach(feedResponse -> queryResults.addAll(feedResponse.getResults()));

Expand All @@ -108,18 +138,21 @@ public void queryDocuments() {
for (int i = 0; i < expectedDocumentsList.size(); i++) {
assertThat(expectedDocumentsList.get(i).toString().equals(queryResults.get(i).toString()));
}

double totalRequestCharge = queryResultPages.stream().collect(Collectors.summingDouble(FeedResponse::getRequestCharge));
assertThat(totalRequestCharge).isGreaterThan(0);
}

/**
 * Generates the test data and inserts {@code docs} into {@code createdCollection} through
 * {@code voidBulkInsertBlocking}.
 */
public void bulkInsert() {
// NOTE(review): two generateTestData calls appear back to back; confirm only one is intended,
// since each call adds entries to docs.
generateTestData();
generateTestData(INSERT_DOCUMENTS_CNT);
voidBulkInsertBlocking(createdCollection, docs);
}

public void generateTestData() {
public void generateTestData(int documentCnt) {
personList = new ArrayList<>();
Random rand = new Random();
ObjectMapper mapper = new ObjectMapper();
for (int i = 0; i < 40; i++) {
for (int i = 0; i < documentCnt; i++) {
Person person = getRandomPerson(rand);
try {
docs.add(new InternalObjectNode(mapper.writeValueAsString(person)));
Expand Down Expand Up @@ -153,13 +186,6 @@ public Person getRandomPerson(Random rand) {
return new Person(name, city, income, people, age, pet, guid);
}

/**
 * Groups {@code personList} by city and sums the ages in each group.
 * The resulting map is not stored or returned.
 */
void generateQueryConfig() {
    Map<City, Integer> ageTotalsByCity = personList
        .stream()
        .collect(
            Collectors.groupingBy(
                person -> person.getCity(),
                Collectors.summingInt(person -> person.getAge())));
}

@AfterClass(groups = {"simple"}, timeOut = SHUTDOWN_TIMEOUT, alwaysRun = true)
public void afterClass() {
safeClose(client);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,13 @@ public City getCity() {
/**
 * Getter for property 'age'.
 *
 * @return Value for property 'age'.
 */
public int getAge() {
    return this.age;
}

/**
 * Returns the value of the 'guid' property.
 *
 * @return the current 'guid' value.
 */
public UUID getGuid() {
    return this.guid;
}
}

0 comments on commit 9b12aa4

Please sign in to comment.