alessandrobenedetti commented on code in PR #3316:
URL: https://github.com/apache/solr/pull/3316#discussion_r2047337900


##########
solr/core/src/java/org/apache/solr/search/join/BlockJoinParentQParser.java:
##########
@@ -78,10 +93,127 @@ protected Query noClausesQuery() throws SyntaxError {
     return new BitSetProducerQuery(getBitSetProducer(parseParentFilter()));
   }
 
-  protected Query createQuery(final Query parentList, Query query, String 
scoreMode)
+  protected Query createQuery(final Query allParents, BooleanQuery 
childrenQuery, String scoreMode)
       throws SyntaxError {
-    return new AllParentsAware(
-        query, getBitSetProducer(parentList), 
ScoreModeParser.parse(scoreMode), parentList);
+      try {
+          List<BooleanClause> childrenClauses = childrenQuery.clauses();
+          if (isByteKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnByteVectorQuery knnChildrenQuery = (KnnByteVectorQuery) 
childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            byte[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query acceptedChildren =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenByteKnnVectorQuery(
+                    vectorField, queryVector, acceptedChildren, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+            
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else if (isFloatKnnQuery(childrenClauses)) {
+            BitSetProducer allParentsBitSet = getBitSetProducer(allParents);
+            BooleanQuery parentsFilter = getParentsFilter();
+      
+            KnnFloatVectorQuery knnChildrenQuery =
+                (KnnFloatVectorQuery) childrenClauses.get(0).getQuery();
+            String vectorField = knnChildrenQuery.getField();
+            float[] queryVector = knnChildrenQuery.getTargetCopy();
+            int topK = knnChildrenQuery.getK();
+      
+            Query childrenFilter =
+                getChildrenFilter(knnChildrenQuery.getFilter(), parentsFilter, 
allParentsBitSet);
+      
+            Query knnChildren =
+                new DiversifyingChildrenFloatKnnVectorQuery(
+                    vectorField, queryVector, childrenFilter, topK, 
allParentsBitSet);
+            knnChildren = knnChildren.rewrite(req.getSearcher());
+            this.setAppropriateChildrenListingTransformer(req,knnChildren);
+      
+            return new ToParentBlockJoinQuery(
+                knnChildren, allParentsBitSet, 
ScoreModeParser.parse(scoreMode));
+          } else {
+            return new AllParentsAware(
+                childrenQuery,
+                getBitSetProducer(allParents),
+                ScoreModeParser.parse(scoreMode),
+                allParents);
+          }
+      } catch (IOException e) {
+        throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, e);
+      }
+  }
+
+  private void setAppropriateChildrenListingTransformer(SolrQueryRequest 
request, Query knnOnVectorField) throws IOException {
+    QueryLimits currentLimits = QueryLimits.getCurrentLimits();
+    ReturnFields returnFields = currentLimits.getRsp().getReturnFields();
+    DocTransformer originalTransformer = returnFields.getTransformer();
+
+    if (originalTransformer instanceof DocTransformers) {
+      DocTransformers transformers = (DocTransformers) originalTransformer;
+      boolean noChildTransformer = true;
+      for (int i = 0; i < transformers.size() && noChildTransformer; i++) {
+        DocTransformer t = transformers.getTransformer(i);
+        if (t instanceof ChildDocTransformer) {
+          ChildDocTransformer childTransformer = (ChildDocTransformer) t;
+          if (childTransformer.getChildDocSet() == null) {
+            
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+          }
+          noChildTransformer = false;
+        }
+      }
+    } else {
+      if ((originalTransformer instanceof ChildDocTransformer)) {
+        ChildDocTransformer childTransformer = (ChildDocTransformer) 
originalTransformer;
+        if (childTransformer.getChildDocSet() == null) {
+          
childTransformer.setChildDocSet(request.getSearcher().getDocSet(knnOnVectorField));
+        }
+      }
+    }
+  }
+
+  private boolean isFloatKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnFloatVectorQuery.class);
+  }
+
+  private boolean isByteKnnQuery(List<BooleanClause> childrenClauses) {
+    return childrenClauses.size() == 1
+        && 
childrenClauses.get(0).getQuery().getClass().equals(KnnByteVectorQuery.class);
+  }

Review Comment:
   Done. Take a look and see whether you like it; it may be a little less readable, but I
think it's OK!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@solr.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@solr.apache.org
For additional commands, e-mail: issues-h...@solr.apache.org

Reply via email to