Author: Carl Friedrich Bolz-Tereick <[email protected]>
Branch: py3.6
Changeset: r98687:7666a721d2fe
Date: 2020-02-08 17:49 +0100
http://bitbucket.org/pypy/pypy/changeset/7666a721d2fe/

Log:    slightly overengineered code to improve the performance of str.join
        (helps the list case in some situations as well as the iterator
        case, but the latter is helped more: speedups of >50% when using
        some other iterator)

diff --git a/pypy/objspace/std/test/test_unicodeobject.py b/pypy/objspace/std/test/test_unicodeobject.py
--- a/pypy/objspace/std/test/test_unicodeobject.py
+++ b/pypy/objspace/std/test/test_unicodeobject.py
@@ -145,13 +145,20 @@
             assert a == b
             assert type(a) == type(b)
         check(', '.join(['a']), 'a')
+        check(', '.join(['a', 'b']), 'a, b')
         raises(TypeError, ','.join, [b'a'])
-        exc = raises(TypeError, ''.join, ['a', 2, 3])
+        check(','.join(iter([])), '')
+        with raises(TypeError) as exc:
+            ''.join(['a', 2, 3])
+        assert 'sequence item 1' in str(exc.value)
+        with raises(TypeError) as exc:
+            ''.join(iter(['a', 2, 3]))
         assert 'sequence item 1' in str(exc.value)
         # unicode lists
         check(''.join(['\u1234']), '\u1234')
         check(''.join(['\u1234', '\u2345']), '\u1234\u2345')
         check('\u1234'.join(['\u2345', '\u3456']), '\u2345\u1234\u3456')
+        check('x\u1234y'.join(['a', 'b', 'c']), 'ax\u1234ybx\u1234yc')
         # also checking passing a single unicode instead of a list
         check(''.join('\u1234'), '\u1234')
         check(''.join('\u1234\u2345'), '\u1234\u2345')
diff --git a/pypy/objspace/std/unicodeobject.py b/pypy/objspace/std/unicodeobject.py
--- a/pypy/objspace/std/unicodeobject.py
+++ b/pypy/objspace/std/unicodeobject.py
@@ -41,6 +41,9 @@
     return rutf8.codepoint_at_pos(utf8, p)
 
 
+joindriver = jit.JitDriver(greens = ['selfisnotempty', 'tpfirst', 'tplist'], reds = 'auto',
+                           name='joindriver')
+
 class W_UnicodeObject(W_Root):
     import_from_mixin(StringMethods)
     _immutable_fields_ = ['_utf8', '_length']
@@ -166,9 +169,6 @@
     def _generic_name(self):
         return "str"
 
-    def _generic_name(self):
-        return "str"
-
     def _isupper(self, ch):
         return unicodedb.isupper(ch)
 
@@ -593,15 +593,101 @@
 
         return W_UnicodeObject(expanded, newlen)
 
-    _StringMethods_descr_join = descr_join
-    def descr_join(self, space, w_list):
-        l = space.listview_ascii(w_list)
-        if l is not None and self.is_ascii():
-            if len(l) == 1:
-                return space.newutf8(l[0], len(l[0]))
-            s = self._utf8.join(l)
-            return space.newutf8(s, len(s))
-        return self._StringMethods_descr_join(space, w_list)
+    def _join_utf8_len_w(self, space, w_element, i):  # i: index used in the error message
+        try:
+            return space.utf8_len_w(w_element)  # callers unpack this as (utf8, length)
+        except OperationError as e:
+            if not e.match(space, space.w_TypeError):
+                raise  # only TypeError is rewritten; anything else propagates
+            raise oefmt(space.w_TypeError,
+                        "sequence item %d: expected %s, %T found",
+                        i, self._generic_name(), w_element)
+
+    def _join_ascii(self, space, l):
+        if len(l) == 1:
+            return space.newutf8(l[0], len(l[0]))
+        s = self._utf8.join(l)
+        if self.is_ascii() or len(l) == 0:
+            resultlen = len(s)
+        else:
+            # items are ascii: drop the separators' extra utf-8 bytes (needs len(l) >= 1)
+            resultlen = len(s) - (len(self._utf8) - self._length) * (len(l) - 1)
+        return space.newutf8(s, resultlen)
+
+    def _join_from_list(self, space, w_list):
+        list_w = space.listview(w_list)
+        if len(list_w) == 0:
+            return self.EMPTY
+        if len(list_w) == 1:
+            w_s = list_w[0]
+            # only one item, return it if it's not a subclass of str
+            if self._join_return_one(space, w_s):
+                return w_s
+        # the stringmethods implementation makes a copy of the list to
+        # pre-compute the correct size for preallocation. that sounds like the
+        # wrong tradeoff somehow...
+        # instead, use the first element to guess the preallocation size: the
+        # guess is exact only if all items have the first item's utf-8 size
+        w_first = list_w[0]
+        utf8first, lenfirst = self._join_utf8_len_w(space, w_first, 0)
+        prealloc = len(self._utf8) * (len(list_w) - 1) + len(utf8first) * len(list_w)
+        builder = rutf8.Utf8StringBuilder(prealloc)
+        builder.append_utf8(utf8first, lenfirst)
+        for i in range(1, len(list_w)):
+            w_element = list_w[i]
+            utf8, l = self._join_utf8_len_w(space, w_element, i)
+            if self._length:
+                builder.append_utf8(self._utf8, self._length)
+            builder.append_utf8(utf8, l)
+        return self.from_utf8builder(builder)
+
+    def _join_from_iterable(self, space, w_iterable):
+        sizehint = space.length_hint(w_iterable, -1)
+
+        # get the first element to guess the preallocation size
+        iterator = space.iteriterable(w_iterable)
+        try:
+            w_first = next(iterator)
+        except StopIteration:
+            return W_UnicodeObject.EMPTY
+
+        utf8first, lenfirst = self._join_utf8_len_w(space, w_first, 0)
+        if sizehint > 0:  # a hint of 0 is wrong here and would make prealloc negative
+            prealloc = len(self._utf8) * (sizehint - 1) + len(utf8first) * sizehint
+        else:
+            prealloc = len(self._utf8) + len(utf8first)
+
+        # build the result
+        builder = rutf8.Utf8StringBuilder(prealloc)
+        builder.append_utf8(utf8first, lenfirst)
+        size = 1
+        selfisnotempty = self._length != 0
+        tpfirst = space.type(w_first)
+        tplist = space.type(w_iterable)
+        for w_element in iterator:
+            joindriver.jit_merge_point(tpfirst=tpfirst, tplist=tplist, selfisnotempty=selfisnotempty)
+            if selfisnotempty:
+                builder.append_utf8(self._utf8, self._length)
+            utf8, l = self._join_utf8_len_w(space, w_element, size)
+            builder.append_utf8(utf8, l)
+            size += 1
+        if size == 1 and self._join_return_one(space, w_first):
+            return w_first
+        return W_UnicodeObject.from_utf8builder(builder)
+
+    def descr_join(self, space, w_iterable):
+        from pypy.objspace.std.listobject import W_ListObject
+        # somewhat overengineered, but it's quite common
+
+        # first, a shortcut for when w_iterable is ascii-only
+        l = space.listview_ascii(w_iterable)
+        if l is not None:
+            return self._join_ascii(space, l)
+
+        if type(w_iterable) is W_ListObject or (isinstance(w_iterable, W_ListObject) and
+                                                space._uses_list_iter(w_iterable)):
+            return self._join_from_list(space, w_iterable)
+        return self._join_from_iterable(space, w_iterable)
 
     def _join_return_one(self, space, w_obj):
         return space.is_w(space.type(w_obj), space.w_unicode)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to