walterddr commented on a change in pull request #9732: [FLINK-14153][ml] Add to 
BLAS a method that performs DenseMatrix and SparseVector multiplication.
URL: https://github.com/apache/flink/pull/9732#discussion_r337844336
 
 

 ##########
 File path: 
flink-ml-parent/flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java
 ##########
 @@ -131,19 +172,56 @@ public static void gemm(double alpha, DenseMatrix matA, 
boolean transA, DenseMat
        }
 
        /**
-        * y := alpha * A * x + beta * y .
+        * Check the compatibility of matrix and vector sizes in 
<code>gemv</code>.
         */
-       public static void gemv(double alpha, DenseMatrix matA, boolean transA,
-                                                       DenseVector x, double 
beta, DenseVector y) {
+       private static void gemvDimensionCheck(DenseMatrix matA, boolean 
transA, Vector x, Vector y) {
                if (transA) {
-                       assert (matA.numCols() == y.size() && matA.numRows() == 
x.size()) : "Matrix and vector size mismatched.";
+                       Preconditions.checkArgument(matA.numCols() == y.size() 
&& matA.numRows() == x.size(),
+                               "Matrix and vector size mismatched.");
                } else {
-                       assert (matA.numRows() == y.size() && matA.numCols() == 
x.size()) : "Matrix and vector size mismatched.";
+                       Preconditions.checkArgument(matA.numRows() == y.size() 
&& matA.numCols() == x.size(),
+                               "Matrix and vector size mismatched.");
                }
+       }
+
+       /**
+        * y := alpha * A * x + beta * y .
+        */
+       public static void gemv(double alpha, DenseMatrix matA, boolean transA,
+                                                       DenseVector x, double 
beta, DenseVector y) {
+               gemvDimensionCheck(matA, transA, x, y);
                final int m = matA.numRows();
                final int n = matA.numCols();
                final int lda = matA.numRows();
                final String ta = transA ? "T" : "N";
                NATIVE_BLAS.dgemv(ta, m, n, alpha, matA.getData(), lda, 
x.getData(), 1, beta, y.getData(), 1);
        }
+
+       /**
+        * y := alpha * A * x + beta * y .
+        */
+       public static void gemv(double alpha, DenseMatrix matA, boolean transA,
+                                                       SparseVector x, double 
beta, DenseVector y) {
+               gemvDimensionCheck(matA, transA, x, y);
+               final int m = matA.numRows();
+               final int n = matA.numCols();
+               if (transA) {
+                       int start = 0;
+                       for (int i = 0; i < n; i++) {
+                               double s = 0.;
+                               for (int j = 0; j < x.indices.length; j++) {
+                                       s += x.values[j] * matA.data[start + 
x.indices[j]];
+                               }
+                               y.data[i] = beta * y.data[i] + alpha * s;
+                               start += m;
+                       }
+               } else {
+                       scal(beta, y);
+                       for (int i = 0; i < x.indices.length; i++) {
+                               int index = x.indices[i];
+                               double value = alpha * x.values[i];
+                               F2J_BLAS.daxpy(m, value, matA.data, index * m, 
1, y.data, 0, 1);
 
 Review comment:
   Could you move the explanation you gave in the PR discussion into the code comments? It would help others understand this code in the future.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to