Author: Carl Friedrich Bolz-Tereick <[email protected]>
Branch: jit-generating-extension
Changeset: r98695:c87366a2fc46
Date: 2020-02-11 14:11 +0100
http://bitbucket.org/pypy/pypy/changeset/c87366a2fc46/
Log: hacky experiment: can we write a generating extension for the
meta-interpreter (in the spirit of geninterp)? Right now it only gets
rid of dispatch overhead; some time-shifting and other optimizations
could be done later too.
diff --git a/rpython/jit/codewriter/assembler.py
b/rpython/jit/codewriter/assembler.py
--- a/rpython/jit/codewriter/assembler.py
+++ b/rpython/jit/codewriter/assembler.py
@@ -31,6 +31,7 @@
"""Take the 'ssarepr' representation of the code and assemble
it inside the 'jitcode'. If jitcode is None, make a new one.
"""
+ from rpython.jit.codewriter.genextension import GenExtension
self.setup(ssarepr.name)
ssarepr._insns_pos = []
for insn in ssarepr.insns:
@@ -45,6 +46,8 @@
if self._count_jitcodes < 20: # stop if we have a lot of them
jitcode._dump = format_assembler(ssarepr)
self._count_jitcodes += 1
+ if ssarepr.genextension:
+ GenExtension(self).generate(ssarepr, jitcode)
return jitcode
def setup(self, name):
diff --git a/rpython/jit/codewriter/flatten.py
b/rpython/jit/codewriter/flatten.py
--- a/rpython/jit/codewriter/flatten.py
+++ b/rpython/jit/codewriter/flatten.py
@@ -4,10 +4,11 @@
class SSARepr(object):
- def __init__(self, name):
+ def __init__(self, name, genextension=False):
self.name = name
self.insns = []
self._insns_pos = None # after being assembled
+ self.genextension = genextension
class Label(object):
def __init__(self, name):
@@ -79,11 +80,13 @@
self.cpu = cpu
self._include_all_exc_links = _include_all_exc_links
self.registers = {}
+ genextension = False
if graph:
name = graph.name
+ genextension = getattr(graph.func, "generate_jit_extension", False)
else:
name = '?'
- self.ssarepr = SSARepr(name)
+ self.ssarepr = SSARepr(name, genextension)
def enforce_input_args(self):
inputargs = self.graph.startblock.inputargs
diff --git a/rpython/jit/codewriter/genextension.py
b/rpython/jit/codewriter/genextension.py
new file mode 100644
--- /dev/null
+++ b/rpython/jit/codewriter/genextension.py
@@ -0,0 +1,345 @@
+
+class GenExtension(object):
+ def __init__(self, assembler):
+ self.assembler = assembler
+ self.insns = [None] * len(assembler.insns)
+ for insn, index in assembler.insns.iteritems():
+ self.insns[index] = insn
+
+ def setup(self, ssarepr, jitcode):
+ self.ssarepr = ssarepr
+ self.jitcode = jitcode
+ self.precode = []
+ self.code = []
+ self.globals = {}
+
+ def generate(self, ssarepr, jitcode):
+ from rpython.jit.codewriter.flatten import Label
+ self.setup(ssarepr, jitcode)
+ self.precode.append("def f(self):")
+ self.precode.append(" pc = self.pc")
+ self.precode.append(" while 1:")
+ for index, insn in enumerate(ssarepr.insns):
+ if isinstance(insn[0], Label):
+ continue
+ pc = ssarepr._insns_pos[index]
+ self.code.append("if pc == %s:" % pc)
+ if index == len(self.ssarepr.insns) - 1:
+ nextpc = len(self.jitcode.code)
+ else:
+ nextpc = self.ssarepr._insns_pos[index + 1]
+ lines, needed_orgpc, needed_label = self._emit_instruction(insn,
index, pc, nextpc)
+ for line in lines:
+ self.code.append(" " + line)
+ pcs = self.next_possible_pcs(insn, needed_label, nextpc)
+ if len(pcs) == 0:
+ self.code.append(" assert 0 # unreachable")
+ continue
+ elif len(pcs) == 1:
+ self.code.append(" pc = %s" % pcs[0])
+ else:
+ self.code.append(" pc = self.pc")
+ # do the trick
+ for pc in pcs:
+ self.code.append(" if pc == %s: pc = %s" % (pc, pc))
+ self.code.append(" else: assert 0 # unreachable")
+ self.code.append(" continue")
+ allcode = []
+ allcode.extend(self.precode)
+ for line in self.code:
+ allcode.append(" " * 8 + line)
+ jitcode._genext_source = "\n".join(allcode)
+
+ def _emit_instruction(self, insn, index, pc, nextpc):
+ from rpython.jit.metainterp.pyjitpl import MIFrame
+ from rpython.jit.metainterp.blackhole import signedord
+ lines = []
+ # first, write self.pc
+ lines.append("self.pc = %s" % (nextpc, ))
+ instruction = self.insns[ord(self.jitcode.code[pc])]
+ name, argcodes = instruction.split("/")
+ methodname = 'opimpl_' + name
+ unboundmethod = getattr(MIFrame, methodname).im_func
+ argtypes = unboundmethod.argtypes
+
+ # collect arguments, this is a 'timeshifted' version of the code in
+ # pyjitpl._get_opimpl_method
+ args = []
+ next_argcode = 0
+ code = self.jitcode.code
+ orgpc = pc
+ position = pc
+ position += 1
+ needed_orgpc = False
+ needed_label = None
+ for argtype in argtypes:
+ if argtype == "box": # a box, of whatever type
+ argcode = argcodes[next_argcode]
+ next_argcode = next_argcode + 1
+ if argcode == 'i':
+ value = "self.registers_i[%s]" % (ord(code[position]), )
+ elif argcode == 'c':
+ value = "ConstInt(%s)" % signedord(code[position])
+ elif argcode == 'r':
+ value = "self.registers_r[%s]" % (ord(code[position]), )
+ elif argcode == 'f':
+ value = "self.registers_f[%s]" % (ord(code[position]), )
+ else:
+ raise AssertionError("bad argcode")
+ position += 1
+ elif argtype == "descr" or argtype == "jitcode":
+ assert argcodes[next_argcode] == 'd'
+ next_argcode = next_argcode + 1
+ index = ord(code[position]) | (ord(code[position+1])<<8)
+ import pdb; pdb.set_trace()
+ value = "self.metainterp.staticdata.opcode_descrs[%s]" % index
+ if argtype == "jitcode":
+ assert isinstance(value, JitCode)
+ position += 2
+ elif argtype == "label":
+ assert argcodes[next_argcode] == 'L'
+ next_argcode = next_argcode + 1
+ assert needed_label is None # only one label per instruction
+ needed_label = ord(code[position]) | (ord(code[position+1])<<8)
+ value = str(needed_label)
+ position += 2
+ elif argtype == "boxes": # a list of boxes of some type
+ length = ord(code[position])
+ value = [None] * length
+ self.prepare_list_of_boxes(value, 0, position,
+ argcodes[next_argcode])
+ next_argcode = next_argcode + 1
+ position += 1 + length
+ value = str(value)
+ elif argtype == "boxes2": # two lists of boxes merged into one
+ length1 = ord(code[position])
+ position2 = position + 1 + length1
+ length2 = ord(code[position2])
+ value = [None] * (length1 + length2)
+ self.prepare_list_of_boxes(value, 0, position,
+ argcodes[next_argcode])
+ self.prepare_list_of_boxes(value, length1, position2,
+ argcodes[next_argcode + 1])
+ next_argcode = next_argcode + 2
+ position = position2 + 1 + length2
+ value = str(value)
+ elif argtype == "boxes3": # three lists of boxes merged into one
+ length1 = ord(code[position])
+ position2 = position + 1 + length1
+ length2 = ord(code[position2])
+ position3 = position2 + 1 + length2
+ length3 = ord(code[position3])
+ value = [None] * (length1 + length2 + length3)
+ self.prepare_list_of_boxes(value, 0, position,
+ argcodes[next_argcode])
+ self.prepare_list_of_boxes(value, length1, position2,
+ argcodes[next_argcode + 1])
+ self.prepare_list_of_boxes(value, length1 + length2, position3,
+ argcodes[next_argcode + 2])
+ next_argcode = next_argcode + 3
+ position = position3 + 1 + length3
+ value = str(value)
+ elif argtype == "orgpc":
+ value = str(orgpc)
+ needed_orgpc = True
+ elif argtype == "int":
+ argcode = argcodes[next_argcode]
+ next_argcode = next_argcode + 1
+ if argcode == 'i':
+ value = "self.registers_i[%s].getint()" %
(ord(code[position]), )
+ elif argcode == 'c':
+ value = str(signedord(code[position]))
+ else:
+ raise AssertionError("bad argcode")
+ position += 1
+ elif argtype == "jitcode_position":
+ value = str(position)
+ else:
+ raise AssertionError("bad argtype: %r" % (argtype,))
+ args.append(value)
+ strargs = ", ".join(args)
+
+ num_return_args = len(argcodes) - next_argcode
+ assert num_return_args == 0 or num_return_args == 2
+ if num_return_args:
+ # Save the type of the resulting box. This is needed if there is
+ # a get_list_of_active_boxes(). See comments there.
+ lines.append("self._result_argcode = %r" % (argcodes[next_argcode
+ 1], ))
+ resindex = ord(code[position])
+ if argcodes[next_argcode + 1] == "i":
+ prefix = "self.registers_i[%s] = " % resindex
+ elif argcodes[next_argcode + 1] == "r":
+ prefix = "self.registers_r[%s] = " % resindex
+ elif argcodes[next_argcode + 1] == "f":
+ prefix = "self.registers_f[%s] = " % resindex
+ else:
+ assert 0
+ position += 1
+ else:
+ lines.append("self._result_argcode = 'v'")
+ prefix = ''
+
+ lines.append("%sself.%s(%s)" % (prefix, methodname, strargs))
+ return lines, needed_orgpc, needed_label
+
+ def prepare_list_of_boxes(self, outvalue, startindex, position, argcode):
+ assert argcode in 'IRF'
+ code = self.jitcode.code
+ length = ord(code[position])
+ position += 1
+ for i in range(length):
+ index = ord(code[position+i])
+ if argcode == 'I': reg = "self.registers_i[%s]" % index
+ elif argcode == 'R': reg = "self.registers_r[%s]" % index
+ elif argcode == 'F': reg = "self.registers_f[%s]" % index
+ else: raise AssertionError(argcode)
+ outvalue[startindex+i] = reg
+
+ def next_possible_pcs(self, insn, needed_label, nextpc):
+ if needed_label:
+ return [nextpc, needed_label]
+ if insn[0].endswith("return"):
+ return []
+ if insn[0].endswith("raise"):
+ return []
+ if insn[0] == "switch":
+ import pdb; pdb.set_trace()
+ else:
+ return [nextpc]
+
+
+
+def _get_opimpl_method(name, argcodes):
+ #
+ def handler(self, position):
+ assert position >= 0
+ args = ()
+ next_argcode = 0
+ code = self.bytecode
+ orgpc = position
+ position += 1
+ for argtype in argtypes:
+ if argtype == "box": # a box, of whatever type
+ argcode = argcodes[next_argcode]
+ next_argcode = next_argcode + 1
+ if argcode == 'i':
+ value = self.registers_i[ord(code[position])]
+ elif argcode == 'c':
+ value = ConstInt(signedord(code[position]))
+ elif argcode == 'r':
+ value = self.registers_r[ord(code[position])]
+ elif argcode == 'f':
+ value = self.registers_f[ord(code[position])]
+ else:
+ raise AssertionError("bad argcode")
+ position += 1
+ elif argtype == "descr" or argtype == "jitcode":
+ assert argcodes[next_argcode] == 'd'
+ next_argcode = next_argcode + 1
+ index = ord(code[position]) | (ord(code[position+1])<<8)
+ value = self.metainterp.staticdata.opcode_descrs[index]
+ if argtype == "jitcode":
+ assert isinstance(value, JitCode)
+ position += 2
+ elif argtype == "label":
+ assert argcodes[next_argcode] == 'L'
+ next_argcode = next_argcode + 1
+ value = ord(code[position]) | (ord(code[position+1])<<8)
+ position += 2
+ elif argtype == "boxes": # a list of boxes of some type
+ length = ord(code[position])
+ value = [None] * length
+ self.prepare_list_of_boxes(value, 0, position,
+ argcodes[next_argcode])
+ next_argcode = next_argcode + 1
+ position += 1 + length
+ elif argtype == "boxes2": # two lists of boxes merged into one
+ length1 = ord(code[position])
+ position2 = position + 1 + length1
+ length2 = ord(code[position2])
+ value = [None] * (length1 + length2)
+ self.prepare_list_of_boxes(value, 0, position,
+ argcodes[next_argcode])
+ self.prepare_list_of_boxes(value, length1, position2,
+ argcodes[next_argcode + 1])
+ next_argcode = next_argcode + 2
+ position = position2 + 1 + length2
+ elif argtype == "boxes3": # three lists of boxes merged into one
+ length1 = ord(code[position])
+ position2 = position + 1 + length1
+ length2 = ord(code[position2])
+ position3 = position2 + 1 + length2
+ length3 = ord(code[position3])
+ value = [None] * (length1 + length2 + length3)
+ self.prepare_list_of_boxes(value, 0, position,
+ argcodes[next_argcode])
+ self.prepare_list_of_boxes(value, length1, position2,
+ argcodes[next_argcode + 1])
+ self.prepare_list_of_boxes(value, length1 + length2, position3,
+ argcodes[next_argcode + 2])
+ next_argcode = next_argcode + 3
+ position = position3 + 1 + length3
+ elif argtype == "orgpc":
+ value = orgpc
+ elif argtype == "int":
+ argcode = argcodes[next_argcode]
+ next_argcode = next_argcode + 1
+ if argcode == 'i':
+ value = self.registers_i[ord(code[position])].getint()
+ elif argcode == 'c':
+ value = signedord(code[position])
+ else:
+ raise AssertionError("bad argcode")
+ position += 1
+ elif argtype == "jitcode_position":
+ value = position
+ else:
+ raise AssertionError("bad argtype: %r" % (argtype,))
+ args += (value,)
+ #
+ num_return_args = len(argcodes) - next_argcode
+ assert num_return_args == 0 or num_return_args == 2
+ if num_return_args:
+ # Save the type of the resulting box. This is needed if there is
+ # a get_list_of_active_boxes(). See comments there.
+ self._result_argcode = argcodes[next_argcode + 1]
+ position += 1
+ else:
+ self._result_argcode = 'v'
+ self.pc = position
+ #
+ if not we_are_translated():
+ if self.debug:
+ print '\tpyjitpl: %s(%s)' % (name, ', '.join(map(repr, args))),
+ try:
+ resultbox = unboundmethod(self, *args)
+ except Exception as e:
+ if self.debug:
+ print '-> %s!' % e.__class__.__name__
+ raise
+ if num_return_args == 0:
+ if self.debug:
+ print
+ assert resultbox is None
+ else:
+ if self.debug:
+ print '-> %r' % (resultbox,)
+ assert argcodes[next_argcode] == '>'
+ result_argcode = argcodes[next_argcode + 1]
+ if 'ovf' not in name:
+ assert resultbox.type == {'i': history.INT,
+ 'r': history.REF,
+ 'f':
history.FLOAT}[result_argcode]
+ else:
+ resultbox = unboundmethod(self, *args)
+ #
+ if resultbox is not None:
+ self.make_result_of_lastop(resultbox)
+ elif not we_are_translated():
+ assert self._result_argcode in 'v?' or 'ovf' in name
+ #
+ unboundmethod = getattr(MIFrame, 'opimpl_' + name).im_func
+ argtypes = unrolling_iterable(unboundmethod.argtypes)
+ handler.func_name = 'handler_' + name
+ return handler
+
diff --git a/rpython/jit/codewriter/test/test_genextension.py
b/rpython/jit/codewriter/test/test_genextension.py
new file mode 100644
--- /dev/null
+++ b/rpython/jit/codewriter/test/test_genextension.py
@@ -0,0 +1,56 @@
+from rpython.flowspace.model import Constant
+from rpython.jit.codewriter.flatten import SSARepr, Label, TLabel, Register
+from rpython.jit.codewriter.assembler import Assembler, AssemblerError
+from rpython.rtyper.lltypesystem import lltype, llmemory
+
+def test_assemble_loop():
+ ssarepr = SSARepr("test", genextension=True)
+ i0, i1 = Register('int', 0x16), Register('int', 0x17)
+ ssarepr.insns = [
+ (Label('L1'),),
+ ('goto_if_not_int_gt', i0, Constant(4, lltype.Signed), TLabel('L2')),
+ ('int_add', i1, i0, '->', i1),
+ ('int_sub', i0, Constant(1, lltype.Signed), '->', i0),
+ ('goto', TLabel('L1')),
+ (Label('L2'),),
+ ('int_return', i1),
+ ]
+ assembler = Assembler()
+ jitcode = assembler.assemble(ssarepr)
+ assert jitcode._genext_source == """\
+def f(self):
+ pc = self.pc
+ while 1:
+ if pc == 0:
+ self.pc = 5
+ self._result_argcode = 'v'
+ self.opimpl_goto_if_not_int_gt(self.registers_i[22], ConstInt(4),
16, 0)
+ pc = self.pc
+ if pc == 5: pc = 5
+ if pc == 16: pc = 16
+ else: assert 0 # unreachable
+ continue
+ if pc == 5:
+ self.pc = 9
+ self._result_argcode = 'i'
+ self.registers_i[23] = self.opimpl_int_add(self.registers_i[23],
self.registers_i[22])
+ pc = 9
+ continue
+ if pc == 9:
+ self.pc = 13
+ self._result_argcode = 'i'
+ self.registers_i[22] = self.opimpl_int_sub(self.registers_i[22],
ConstInt(1))
+ pc = 13
+ continue
+ if pc == 13:
+ self.pc = 16
+ self._result_argcode = 'v'
+ self.opimpl_goto(0)
+ pc = 16
+ continue
+ if pc == 16:
+ self.pc = 18
+ self._result_argcode = 'v'
+ self.opimpl_int_return(self.registers_i[23])
+ assert 0 # unreachable"""
+
diff --git a/rpython/rlib/jit.py b/rpython/rlib/jit.py
--- a/rpython/rlib/jit.py
+++ b/rpython/rlib/jit.py
@@ -1292,6 +1292,10 @@
hop.exception_is_here()
return hop.genop(opname, args_v, resulttype=resulttype)
+def warmup_critical_function(func):
+ func.generate_jit_extension = True
+ return func
+
def enter_portal_frame(unique_id):
"""call this when starting to interpret a function. calling this is not
necessary for almost all interpreters. The only exception is stackless
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit