From 2516b7bad0cf837a3418e1a21bdf0c4b96d7e432 Mon Sep 17 00:00:00 2001 From: tzickel Date: Sat, 16 Jun 2018 18:15:51 +0300 Subject: [PATCH] Added support for streaming multipart decoding --- requests_toolbelt/__init__.py | 4 +- requests_toolbelt/multipart/__init__.py | 3 +- requests_toolbelt/multipart/decoder.py | 216 +++++++++++++++++++++++- tests/test_multipart_decoder.py | 64 ++++++- 4 files changed, 279 insertions(+), 8 deletions(-) diff --git a/requests_toolbelt/__init__.py b/requests_toolbelt/__init__.py index 55362461..7719fe27 100644 --- a/requests_toolbelt/__init__.py +++ b/requests_toolbelt/__init__.py @@ -13,7 +13,8 @@ from .auth.guess import GuessAuth from .multipart import ( MultipartEncoder, MultipartEncoderMonitor, MultipartDecoder, - ImproperBodyPartContentException, NonMultipartContentTypeException + ImproperBodyPartContentException, NonMultipartContentTypeException, + MultipartStreamDecoder ) from .streaming_iterator import StreamingIterator from .utils.user_agent import user_agent @@ -31,4 +32,5 @@ 'StreamingIterator', 'user_agent', 'ImproperBodyPartContentException', 'NonMultipartContentTypeException', '__title__', '__authors__', '__license__', '__copyright__', '__version__', '__version_info__', + 'MultipartStreamDecoder' ] diff --git a/requests_toolbelt/multipart/__init__.py b/requests_toolbelt/multipart/__init__.py index 4bc49660..9f01a811 100644 --- a/requests_toolbelt/multipart/__init__.py +++ b/requests_toolbelt/multipart/__init__.py @@ -9,7 +9,7 @@ """ from .encoder import MultipartEncoder, MultipartEncoderMonitor -from .decoder import MultipartDecoder +from .decoder import MultipartDecoder, MultipartStreamDecoder from .decoder import ImproperBodyPartContentException from .decoder import NonMultipartContentTypeException @@ -22,6 +22,7 @@ 'MultipartEncoder', 'MultipartEncoderMonitor', 'MultipartDecoder', + 'MultipartStreamDecoder', 'ImproperBodyPartContentException', 'NonMultipartContentTypeException', '__title__', diff --git a/requests_toolbelt/multipart/decoder.py b/requests_toolbelt/multipart/decoder.py index 19d86af9..17a81f0b 100644 --- a/requests_toolbelt/multipart/decoder.py +++ b/requests_toolbelt/multipart/decoder.py @@ -107,11 +107,12 @@ def __init__(self, content, content_type, encoding='utf-8'): self.encoding = encoding #: Parsed parts of the multipart response body self.parts = tuple() - self._find_boundary() + self.boundary = MultipartDecoder._find_boundary(content_type, encoding) self._parse_body(content) - def _find_boundary(self): - ct_info = tuple(x.strip() for x in self.content_type.split(';')) + @staticmethod + def _find_boundary(content_type, encoding): + ct_info = tuple(x.strip() for x in content_type.split(';')) mimetype = ct_info[0] if mimetype.split('/')[0].lower() != 'multipart': raise NonMultipartContentTypeException( @@ -123,7 +124,8 @@ def _find_boundary(self): '=' ) if attr.lower() == 'boundary': - self.boundary = encode_with(value.strip('"'), self.encoding) + boundary = encode_with(value.strip('"'), encoding) + return boundary @staticmethod def _fix_first_part(part, boundary_marker): @@ -154,3 +156,209 @@ def from_response(cls, response, encoding='utf-8'): content = response.content content_type = response.headers.get('content-type', None) return cls(content, content_type, encoding) + + +# This is thrown when the object is being reiterated from another place +class AlreadyIteratedException(Exception): + pass + + +# This is thrown when trying to skip to the next part without finishing to +# stream the previous one +class PreviousPartNotFinishedException(Exception): + pass + + +class StreamPart(object): + def __init__(self, headers, encoding, iterator): + self.headers = headers + self.encoding = encoding + self._iterator = iterator + self._started = False + self._consumed = False + self._content = None + self._finished = False + + def __iter__(self): + if self._started: + raise AlreadyIteratedException() + self._started = True + for typ, data in self._iterator(): + if typ == 'done' and data is False: + self._finished = True + break + elif typ == 'stream': + yield data + else: + raise ImproperBodyPartContentException() + + @property + def content(self): + if self._consumed: + return self._content + if self._started: + raise AlreadyIteratedException() + self._content = b''.join(self) + self._consumed = True + return self._content + + @property + def text(self): + return self.content.decode(self.encoding) + + +class MultipartStreamDecoder(object): + @classmethod + def from_response(cls, response, encoding='utf-8', chunk_size=10 * 1024, + header_size_limit=None): + def content(): + return response.raw.read(chunk_size) + content_type = response.headers.get('content-type', None) + return cls(content, content_type, encoding, header_size_limit) + + def __init__(self, stream_read_func, content_type, encoding='utf-8', + header_size_limit=None): + self.content_type = content_type + self.encoding = encoding + self._stream_read_func = stream_read_func + self._header_size_limit = header_size_limit + self._boundary = MultipartDecoder._find_boundary(content_type, + encoding) + self._splitter = StreamSplitter() + self._boundary = b''.join((b'--', self._boundary)) + self._boundary_split = b''.join((b'\r\n', self._boundary)) + self._state = 0 + self._found = False + self._started = False + self._finished = False + + # Call this to drain stream when error occured, or you decide + # not to read all data + def close(self): + if not self._finished: + while True: + try: + data = self._stream_read_func() + if not data: + break + # Protection if _stream_read_func is an generator + except StopIteration: + break + finally: + self._finished = True + + # The instance can be used as an context manager for automatic + # draining the stream for re-use + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.close() + + def __iter__(self): + if self._started: + raise AlreadyIteratedException() + self._started = True + # This is for guarding against iterating before finishing to + # iterate on the current part. + _current_stream = None + for typ, data in self._iterator(): + if _current_stream and not _current_stream._finished: + raise PreviousPartNotFinishedException() + if typ == 'headers': + _current_stream = StreamPart( + data, self.encoding, self._iterator + ) + yield _current_stream + else: + raise ImproperBodyPartContentException() + + def _iterator(self): + while True: + data = self._stream_read_func() + # This persumes that if data returned empty once it won't return + # anything again (EOS) + if not self._found and not data: + self._finished = True + break + # This part mimics the _fix_first_part logic from above + if self._state == 0: + should_be_empty_or_crlf, self._found = self._splitter.stream( + data, self._boundary, True + ) + if should_be_empty_or_crlf: + if should_be_empty_or_crlf != b'\r\n': + raise ImproperBodyPartContentException() + self._state = 1 + continue + if self._found: + self._state = 1 + continue + # Parse the headers + elif self._state == 1: + headers, self._found = self._splitter.stream(data, b'\r\n\r\n', + True) + if headers: + headers = _header_parser(headers.lstrip(), self.encoding) + headers = CaseInsensitiveDict(headers) + self._state = 2 + yield 'headers', headers + continue + # No headers found + if self._found: + headers = CaseInsensitiveDict({}) + self._state = 2 + yield 'headers', headers + continue + # This is to protect against malformed input where a header + # does not exist in a limit for performence reasons + if self._header_size_limit: + if (self._splitter.leftover_length > + self._header_size_limit): + raise ImproperBodyPartContentException() + # Stream the part + elif self._state == 2: + stream, self._found = self._splitter.stream( + data, self._boundary_split + ) + if stream: + yield 'stream', stream + # boundary_split found, end of part + if self._found: + self._state = 1 + yield 'done', False + continue + + +class StreamSplitter(object): + def __init__(self): + self.leftover = b'' + + def stream(self, data, split_data, return_only_full=False): + self.leftover += data + index = self.leftover.find(split_data) + if return_only_full: + if index > -1: + ret = self.leftover[:index] + self.leftover = self.leftover[index + len(split_data):] + found = True + else: + ret = b'' + found = False + else: + if index > -1: + ret = self.leftover[:index] + self.leftover = self.leftover[index + len(split_data):] + found = True + elif len(self.leftover) >= len(split_data): + ret = self.leftover[:-len(split_data)] + self.leftover = self.leftover[-len(split_data):] + found = False + else: + ret = b'' + found = False + return ret, found + + @property + def leftover_length(self): + return len(self.leftover) diff --git a/tests/test_multipart_decoder.py b/tests/test_multipart_decoder.py index 1862ed5c..d7c186c6 100644 --- a/tests/test_multipart_decoder.py +++ b/tests/test_multipart_decoder.py @@ -9,7 +9,9 @@ from requests_toolbelt.multipart.decoder import ( ImproperBodyPartContentException ) -from requests_toolbelt.multipart.decoder import MultipartDecoder +from requests_toolbelt.multipart.decoder import ( + MultipartDecoder, MultipartStreamDecoder +) from requests_toolbelt.multipart.decoder import ( NonMultipartContentTypeException ) @@ -104,19 +106,41 @@ def setUp(self): ) self.boundary = 'test boundary' self.encoded_1 = MultipartEncoder(self.sample_1, self.boundary) + self.encoded_1_string = self.encoded_1.to_string() self.decoded_1 = MultipartDecoder( - self.encoded_1.to_string(), + self.encoded_1_string, + self.encoded_1.content_type + ) + + def CreateMultipartStreamDecoder(self): + data = io.BytesIO(self.encoded_1_string) + + def read_function(): + return data.read(10) + + decoder = MultipartStreamDecoder( + read_function, self.encoded_1.content_type ) + + parts = [] + for part in decoder: + part.content + parts.append(part) + + return parts def test_non_multipart_response_fails(self): jpeg_response = mock.NonCallableMagicMock(spec=requests.Response) jpeg_response.headers = {'content-type': 'image/jpeg'} with pytest.raises(NonMultipartContentTypeException): MultipartDecoder.from_response(jpeg_response) + with pytest.raises(NonMultipartContentTypeException): + MultipartStreamDecoder.from_response(jpeg_response) def test_length_of_parts(self): assert len(self.sample_1) == len(self.decoded_1.parts) + assert len(self.sample_1) == len(self.CreateMultipartStreamDecoder()) def test_content_of_parts(self): def parts_equal(part, sample): @@ -124,6 +148,8 @@ def parts_equal(part, sample): parts_iter = zip(self.decoded_1.parts, self.sample_1) assert all(parts_equal(part, sample) for part, sample in parts_iter) + parts_iter = zip(self.CreateMultipartStreamDecoder(), self.sample_1) + assert all(parts_equal(part, sample) for part, sample in parts_iter) def test_header_of_parts(self): def parts_header_equal(part, sample): @@ -136,6 +162,11 @@ def parts_header_equal(part, sample): parts_header_equal(part, sample) for part, sample in parts_iter ) + parts_iter = zip(self.CreateMultipartStreamDecoder(), self.sample_1) + assert all( + parts_header_equal(part, sample) + for part, sample in parts_iter + ) def test_from_response(self): response = mock.NonCallableMagicMock(spec=requests.Response) @@ -163,6 +194,21 @@ def test_from_response(self): assert len(decoder_2.parts[1].headers) == 0 assert decoder_2.parts[1].content == b'Body 2, Line 1' + cnt.seek(0) + response.raw = cnt + decoder_2 = MultipartStreamDecoder.from_response(response) + parts = [] + for part in decoder_2: + part.content + parts.append(part) + assert decoder_2.content_type == response.headers['content-type'] + assert ( + parts[0].content == b'Body 1, Line 1\r\nBody 1, Line 2' + ) + assert parts[0].headers[b'Header-1'] == b'Header-Value-1' + assert len(parts[1].headers) == 0 + assert parts[1].content == b'Body 2, Line 1' + def test_from_responsecaplarge(self): response = mock.NonCallableMagicMock(spec=requests.Response) response.headers = { @@ -189,3 +235,17 @@ def test_from_responsecaplarge(self): assert len(decoder_2.parts[1].headers) == 0 assert decoder_2.parts[1].content == b'Body 2, Line 1' + cnt.seek(0) + response.raw = cnt + decoder_2 = MultipartStreamDecoder.from_response(response) + parts = [] + for part in decoder_2: + part.content + parts.append(part) + assert decoder_2.content_type == response.headers['content-type'] + assert ( + parts[0].content == b'Body 1, Line 1\r\nBody 1, Line 2' + ) + assert parts[0].headers[b'Header-1'] == b'Header-Value-1' + assert len(parts[1].headers) == 0 + assert parts[1].content == b'Body 2, Line 1'