From a3f1a282ffb7e9d459ba3d1135536504bd89f597 Mon Sep 17 00:00:00 2001
From: Yonghong Song <yhs@fb.com>
Date: Fri, 22 Sep 2017 13:52:55 -0700
Subject: [PATCH] bpf: add support for neg insn and change format of bswap insn

[Alexei,

 This patch demonstrates that we can support
 bswap/neg like:
   reg1 = (be16) (u16) reg1
   reg1 = -reg1
 At the IR level, we already have constraints to ensure
 that the src reg is the same as the dst reg. The only issue
 is the assembler (from .s to .o), where the constraint
 check in BPFInstrInfo.td is not effective.
 I added an additional check in BPFAsmParser.cpp which
 can help warn the user if the src/dst are not the same
 for the above insns. Without this check, wrong code
 will be generated.

 My previous experiment tried to use the same dst
 register in two different places to enforce the
 same src/dst register. Now, with the additional check
 in BPFAsmParser, we can use both src and dst registers
 in the asmstring and still guarantee they point to the
 same register.

 From the LLVM point of view, the following format is the
 simplest, and requires no special precheck insn matching
 for the assembler:
   be16 <reg>
   neg  <reg>
 But this syntax is not consistent with the other arith
 syntax we have now.

 Therefore, let us go back to the C-style syntax.
 Not sure we want
   "reg1 = (be16) (u16) reg1" or
   "reg1 = (be16) reg1".
 Maybe the second choice is good enough?

 Regarding the 32-bit syntax, with Jiong's patch, 32-bit ALU
 operations will have syntax with registers "w0, w1, ..., w10"
 instead of "(u32)r0". I guess we can deal with this
 in the verifier later when 32-bit support matures.

 Once we have consensus, I can send an email out to Edward
 about the proposal and the reason.

]

Signed-off-by: Yonghong Song <yhs@fb.com>
---
 lib/Target/BPF/AsmParser/BPFAsmParser.cpp | 51 ++++++++++++++++++++++++++-----
 lib/Target/BPF/BPFInstrFormats.td         |  1 +
 lib/Target/BPF/BPFInstrInfo.td            | 28 ++++++++++++++---
 3 files changed, 68 insertions(+), 12 deletions(-)

diff --git a/lib/Target/BPF/AsmParser/BPFAsmParser.cpp b/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
index d00200c..683f7fd 100644
--- a/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
+++ b/lib/Target/BPF/AsmParser/BPFAsmParser.cpp
@@ -30,6 +30,8 @@ struct BPFOperand;
 class BPFAsmParser : public MCTargetAsmParser {
   SMLoc getLoc() const { return getParser().getTok().getLoc(); }
 
+  bool PreMatchCheck(OperandVector &Operands);
+
   bool MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
                                OperandVector &Operands, MCStreamer &Out,
                                uint64_t &ErrorInfo,
@@ -225,9 +227,6 @@ public:
         .Case("*", true)
         .Case("exit", true)
         .Case("lock", true)
-        .Case("bswap64", true)
-        .Case("bswap32", true)
-        .Case("bswap16", true)
         .Case("ld_pseudo", true)
         .Default(false);
   }
@@ -239,6 +238,9 @@ public:
         .Case("u32", true)
         .Case("u16", true)
         .Case("u8", true)
+        .Case("be64", true)
+        .Case("be32", true)
+        .Case("be16", true)
         .Case("goto", true)
         .Case("ll", true)
         .Case("skb", true)
@@ -252,6 +254,41 @@ public:
 #define GET_MATCHER_IMPLEMENTATION
 #include "BPFGenAsmMatcher.inc"
 
+bool BPFAsmParser::PreMatchCheck(OperandVector &Operands) {
+  // Returns true when a recognized operand shape names two different registers.
+  if (Operands.size() == 4) {
+    // Shape: reg1, "=", "-", reg2. Flag it when reg1 and reg2 differ.
+    BPFOperand &Op0 = (BPFOperand &)*Operands[0];
+    BPFOperand &Op1 = (BPFOperand &)*Operands[1];
+    BPFOperand &Op2 = (BPFOperand &)*Operands[2];
+    BPFOperand &Op3 = (BPFOperand &)*Operands[3];
+    if (Op0.isReg() && Op1.isToken() && Op2.isToken() && Op3.isReg()
+        && Op1.getToken() == "=" && Op2.getToken() == "-"
+        && Op0.getReg() != Op3.getReg())
+      return true;
+  } else if (Operands.size() == 9) {
+    // Shape: reg1 = ( X ) ( Y ) reg2, e.g. (be16) (u16). Flag if reg1 != reg2.
+    BPFOperand &Op0 = (BPFOperand &)*Operands[0];
+    BPFOperand &Op1 = (BPFOperand &)*Operands[1];
+    BPFOperand &Op2 = (BPFOperand &)*Operands[2];
+    BPFOperand &Op3 = (BPFOperand &)*Operands[3];
+    BPFOperand &Op4 = (BPFOperand &)*Operands[4];
+    BPFOperand &Op5 = (BPFOperand &)*Operands[5];
+    BPFOperand &Op6 = (BPFOperand &)*Operands[6];
+    BPFOperand &Op7 = (BPFOperand &)*Operands[7];
+    BPFOperand &Op8 = (BPFOperand &)*Operands[8];
+    if (Op0.isReg() && Op1.isToken() && Op2.isToken() && Op3.isToken()
+        && Op4.isToken() &&  Op5.isToken() && Op6.isToken()
+        && Op7.isToken() && Op8.isReg()
+        && Op1.getToken() == "=" && Op2.getToken() == "("
+        && Op4.getToken() == ")" && Op5.getToken() == "("
+        && Op7.getToken() == ")" && Op0.getReg() != Op8.getReg())
+      return true; // Op3/Op6 are only required to be tokens; text not compared
+  }
+
+  return false; // shape not recognized here, or both registers are the same
+}
+
 bool BPFAsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
                                            OperandVector &Operands,
                                            MCStreamer &Out, uint64_t &ErrorInfo,
@@ -259,6 +296,9 @@ bool BPFAsmParser::MatchAndEmitInstruction(SMLoc IDLoc, unsigned &Opcode,
   MCInst Inst;
   SMLoc ErrorLoc;
 
+  if (PreMatchCheck(Operands))
+    return Error(IDLoc, "additional inst constraint not met");
+
   switch (MatchInstructionImpl(Operands, Inst, ErrorInfo, MatchingInlineAsm)) {
   default:
     break;
@@ -324,13 +364,8 @@ BPFAsmParser::parseOperandAsOperator(OperandVector &Operands) {
   switch (getLexer().getKind()) {
   case AsmToken::Minus:
   case AsmToken::Plus: {
-    StringRef Name = getLexer().getTok().getString();
-
     if (getLexer().peekTok().is(AsmToken::Integer))
       return MatchOperand_NoMatch;
-
-    getLexer().Lex();
-    Operands.push_back(BPFOperand::createToken(Name, S));
   }
   // Fall through.
 
diff --git a/lib/Target/BPF/BPFInstrFormats.td b/lib/Target/BPF/BPFInstrFormats.td
index 1e3bc3b..92d4a62 100644
--- a/lib/Target/BPF/BPFInstrFormats.td
+++ b/lib/Target/BPF/BPFInstrFormats.td
@@ -38,6 +38,7 @@ def BPF_OR   : BPFArithOp<0x4>;
 def BPF_AND  : BPFArithOp<0x5>;
 def BPF_LSH  : BPFArithOp<0x6>;
 def BPF_RSH  : BPFArithOp<0x7>;
+def BPF_NEG  : BPFArithOp<0x8>;
 def BPF_XOR  : BPFArithOp<0xa>;
 def BPF_MOV  : BPFArithOp<0xb>;
 def BPF_ARSH : BPFArithOp<0xc>;
diff --git a/lib/Target/BPF/BPFInstrInfo.td b/lib/Target/BPF/BPFInstrInfo.td
index e1f233e..fc7979b 100644
--- a/lib/Target/BPF/BPFInstrInfo.td
+++ b/lib/Target/BPF/BPFInstrInfo.td
@@ -232,6 +232,26 @@ let isAsCheapAsAMove = 1 in {
   defm DIV : ALU<BPF_DIV, "/=", udiv>;
 }
 
+class NEG_RR<BPFOpClass Class, BPFArithOp Opc,
+             dag outs, dag ins, string asmstr, list<dag> pattern>
+    : TYPE_ALU_JMP<Opc.Value, 0, outs, ins, asmstr, pattern> {
+  bits<4> dst;
+  bits<4> src;
+
+  let Inst{55-52} = src; // 4-bit src field occupies Inst bits 55-52
+  let Inst{51-48} = dst; // 4-bit dst field occupies Inst bits 51-48
+  let BPFClass = Class;
+}
+
+let Constraints = "$dst = $src", isAsCheapAsAMove = 1 in {
+  def NEG_64: NEG_RR<BPF_ALU64, BPF_NEG, (outs GPR:$dst), (ins GPR:$src),
+                     "$dst = -$src",
+                     [(set GPR:$dst, (ineg i64:$src))]>;
+  def NEG_32: NEG_RR<BPF_ALU, BPF_NEG, (outs GPR32:$dst), (ins GPR32:$src),
+                     "$dst = -$src",
+                     [(set GPR32:$dst, (ineg i32:$src))]>;
+}
+
 class LD_IMM64<bits<4> Pseudo, string OpcodeStr>
     : TYPE_LD_ST<BPF_IMM.Value, BPF_DW.Value,
                  (outs GPR:$dst),
@@ -488,7 +508,7 @@ class BSWAP<bits<32> SizeOp, string OpcodeStr, list<dag> Pattern>
     : TYPE_ALU_JMP<BPF_END.Value, BPF_TO_BE.Value,
                    (outs GPR:$dst),
                    (ins GPR:$src),
-                   !strconcat(OpcodeStr, "\t$dst"),
+                   "$dst = (be"#OpcodeStr#") (u"#OpcodeStr#") $src",
                    Pattern> {
   bits<4> dst;
 
@@ -498,9 +518,9 @@ class BSWAP<bits<32> SizeOp, string OpcodeStr, list<dag> Pattern>
 }
 
 let Constraints = "$dst = $src" in {
-def BSWAP16 : BSWAP<16, "bswap16", [(set GPR:$dst, (srl (bswap GPR:$src), (i64 48)))]>;
-def BSWAP32 : BSWAP<32, "bswap32", [(set GPR:$dst, (srl (bswap GPR:$src), (i64 32)))]>;
-def BSWAP64 : BSWAP<64, "bswap64", [(set GPR:$dst, (bswap GPR:$src))]>;
+def BSWAP16 : BSWAP<16, "16", [(set GPR:$dst, (srl (bswap GPR:$src), (i64 48)))]>;
+def BSWAP32 : BSWAP<32, "32", [(set GPR:$dst, (srl (bswap GPR:$src), (i64 32)))]>;
+def BSWAP64 : BSWAP<64, "64", [(set GPR:$dst, (bswap GPR:$src))]>;
 }
 
 let Defs = [R0, R1, R2, R3, R4, R5], Uses = [R6], hasSideEffects = 1,
-- 
2.9.5

