Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion requests_toolbelt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from .auth.guess import GuessAuth
from .multipart import (
MultipartEncoder, MultipartEncoderMonitor, MultipartDecoder,
ImproperBodyPartContentException, NonMultipartContentTypeException
ImproperBodyPartContentException, NonMultipartContentTypeException,
MultipartStreamDecoder
)
from .streaming_iterator import StreamingIterator
from .utils.user_agent import user_agent
Expand All @@ -31,4 +32,5 @@
'StreamingIterator', 'user_agent', 'ImproperBodyPartContentException',
'NonMultipartContentTypeException', '__title__', '__authors__',
'__license__', '__copyright__', '__version__', '__version_info__',
'MultipartStreamDecoder'
]
3 changes: 2 additions & 1 deletion requests_toolbelt/multipart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

from .encoder import MultipartEncoder, MultipartEncoderMonitor
from .decoder import MultipartDecoder
from .decoder import MultipartDecoder, MultipartStreamDecoder
from .decoder import ImproperBodyPartContentException
from .decoder import NonMultipartContentTypeException

Expand All @@ -22,6 +22,7 @@
'MultipartEncoder',
'MultipartEncoderMonitor',
'MultipartDecoder',
'MultipartStreamDecoder',
'ImproperBodyPartContentException',
'NonMultipartContentTypeException',
'__title__',
Expand Down
216 changes: 212 additions & 4 deletions requests_toolbelt/multipart/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,12 @@ def __init__(self, content, content_type, encoding='utf-8'):
self.encoding = encoding
#: Parsed parts of the multipart response body
self.parts = tuple()
self._find_boundary()
self.boundary = MultipartDecoder._find_boundary(content_type, encoding)
self._parse_body(content)

def _find_boundary(self):
ct_info = tuple(x.strip() for x in self.content_type.split(';'))
@staticmethod
def _find_boundary(content_type, encoding):
ct_info = tuple(x.strip() for x in content_type.split(';'))
mimetype = ct_info[0]
if mimetype.split('/')[0].lower() != 'multipart':
raise NonMultipartContentTypeException(
Expand All @@ -123,7 +124,8 @@ def _find_boundary(self):
'='
)
if attr.lower() == 'boundary':
self.boundary = encode_with(value.strip('"'), self.encoding)
boundary = encode_with(value.strip('"'), encoding)
return boundary

@staticmethod
def _fix_first_part(part, boundary_marker):
Expand Down Expand Up @@ -154,3 +156,209 @@ def from_response(cls, response, encoding='utf-8'):
content = response.content
content_type = response.headers.get('content-type', None)
return cls(content, content_type, encoding)


class AlreadyIteratedException(Exception):
    """Raised when an already-started object is iterated again.

    The streaming decoder and its parts can each be iterated only once;
    a second attempt from another place raises this exception.
    """


class PreviousPartNotFinishedException(Exception):
    """Raised when moving on to the next part too early.

    The previous part has to be streamed to its end before the decoder
    may advance to the following one.
    """


class StreamPart(object):
    """A single body part whose payload is pulled lazily from a stream.

    :param headers: headers that were parsed for this part.
    :param encoding: codec used by :attr:`text` to decode the payload.
    :param iterator: zero-argument callable returning an iterable of
        ``(kind, payload)`` pairs. ``('stream', bytes)`` carries a piece
        of the body and ``('done', False)`` marks the end of the part.
    """

    def __init__(self, headers, encoding, iterator):
        self.headers = headers
        self.encoding = encoding
        self._iterator = iterator
        # Bookkeeping flags: iteration begun / end-of-part marker seen /
        # payload fully cached in ``_content``.
        self._started = False
        self._finished = False
        self._consumed = False
        self._content = None

    def __iter__(self):
        """Yield the body of this part piece by piece (single use).

        :raises AlreadyIteratedException: if iteration was started before.
        :raises ImproperBodyPartContentException: if the underlying
            iterator produces anything other than body data or the
            end-of-part marker.
        """
        if self._started:
            raise AlreadyIteratedException()
        self._started = True
        for kind, payload in self._iterator():
            if kind == 'stream':
                yield payload
                continue
            if kind == 'done' and payload is False:
                # End-of-part marker: remember it and stop pulling data.
                self._finished = True
                return
            raise ImproperBodyPartContentException()

    @property
    def content(self):
        """The complete body of this part as bytes.

        The first access drains the stream and caches the result; later
        accesses return the cached bytes.

        :raises AlreadyIteratedException: if the part was already (even
            partially) iterated directly.
        """
        if not self._consumed:
            if self._started:
                raise AlreadyIteratedException()
            self._content = b''.join(self)
            self._consumed = True
        return self._content

    @property
    def text(self):
        """The complete body decoded with :attr:`encoding`."""
        return self.content.decode(self.encoding)


class MultipartStreamDecoder(object):
    """Decode a multipart body incrementally from a stream.

    Iterating the decoder yields one :class:`StreamPart` per body part.
    Each part must be read to its end before the next one is requested.
    The decoder may also be used as a context manager, in which case the
    remainder of the stream is drained on exit.

    :param stream_read_func: zero-argument callable returning the next
        chunk of bytes; an empty result is treated as end of stream.
    :param content_type: value of the ``Content-Type`` header; the
        boundary is extracted from it with
        ``MultipartDecoder._find_boundary``.
    :param encoding: encoding used for the boundary and the headers.
    :param header_size_limit: optional cap on the number of bytes
        buffered while looking for the end of a part's headers; a falsy
        value disables the check.
    """

    # Parser states. The numeric values are the ones ``_state`` has
    # always used, so anything inspecting ``_state`` still sees 0/1/2.
    _STATE_PREAMBLE = 0
    _STATE_HEADERS = 1
    _STATE_BODY = 2

    @classmethod
    def from_response(cls, response, encoding='utf-8', chunk_size=10 * 1024,
                      header_size_limit=None):
        """Create a decoder reading ``chunk_size`` bytes at a time from
        ``response.raw``, using the response's ``content-type`` header.
        """
        def content():
            return response.raw.read(chunk_size)
        content_type = response.headers.get('content-type', None)
        return cls(content, content_type, encoding, header_size_limit)

    def __init__(self, stream_read_func, content_type, encoding='utf-8',
                 header_size_limit=None):
        self.content_type = content_type
        self.encoding = encoding
        self._stream_read_func = stream_read_func
        self._header_size_limit = header_size_limit
        self._boundary = MultipartDecoder._find_boundary(content_type,
                                                         encoding)
        self._splitter = StreamSplitter()
        # Delimiter as it appears at the very start of the body ...
        self._boundary = b''.join((b'--', self._boundary))
        # ... and as it appears between parts (preceded by CRLF).
        self._boundary_split = b''.join((b'\r\n', self._boundary))
        self._state = self._STATE_PREAMBLE
        # True when the most recent split located its delimiter.
        self._found = False
        self._started = False
        self._finished = False

    def close(self):
        """Drain the underlying stream so that it can be re-used.

        Call this after an error occurred, or when you decide not to
        read all of the data. Does nothing if the stream has already
        been read to its end.
        """
        if self._finished:
            return
        try:
            # Read and discard until the source reports end of stream.
            while self._stream_read_func():
                pass
        except StopIteration:
            # Protection in case _stream_read_func is backed by a
            # generator that is already exhausted.
            pass
        finally:
            self._finished = True

    # The instance can be used as a context manager to automatically
    # drain the stream for re-use.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()

    def __iter__(self):
        """Yield a :class:`StreamPart` for every part (single use).

        :raises AlreadyIteratedException: if iteration was started before.
        :raises PreviousPartNotFinishedException: if the next part is
            requested before the current one was read to its end.
        :raises ImproperBodyPartContentException: if the body is
            malformed.
        """
        if self._started:
            raise AlreadyIteratedException()
        self._started = True
        # Guards against advancing before the current part is finished.
        current_part = None
        for typ, data in self._iterator():
            if current_part and not current_part._finished:
                raise PreviousPartNotFinishedException()
            if typ == 'headers':
                current_part = StreamPart(
                    data, self.encoding, self._iterator
                )
                yield current_part
            else:
                raise ImproperBodyPartContentException()

    def _iterator(self):
        """Run the parser and yield ``(kind, payload)`` events.

        Events are ``('headers', CaseInsensitiveDict)`` when a part
        begins, ``('stream', bytes)`` for a piece of its body and
        ``('done', False)`` when the part ends. All parser state lives
        on the instance, so the generators created by the decoder and
        by its parts continue from wherever the previous one stopped.
        """
        while True:
            if self._found:
                # The previous split succeeded, so the buffer may hold
                # further delimiters. Work through what is buffered
                # before pulling more, otherwise the buffer grows by a
                # whole chunk for every delimiter found.
                data = b''
            else:
                data = self._stream_read_func()
                # This presumes that once the source returns nothing it
                # will never return anything again (end of stream).
                if not data:
                    self._finished = True
                    break

            if self._state == self._STATE_PREAMBLE:
                # Mimics MultipartDecoder._fix_first_part: only nothing
                # or a single CRLF may precede the first boundary.
                preamble, self._found = self._splitter.stream(
                    data, self._boundary, True
                )
                if preamble and preamble != b'\r\n':
                    raise ImproperBodyPartContentException()
                if preamble or self._found:
                    self._state = self._STATE_HEADERS

            elif self._state == self._STATE_HEADERS:
                headers, self._found = self._splitter.stream(
                    data, b'\r\n\r\n', True
                )
                if headers:
                    headers = _header_parser(headers.lstrip(), self.encoding)
                    headers = CaseInsensitiveDict(headers)
                    self._state = self._STATE_BODY
                    yield 'headers', headers
                elif self._found:
                    # Blank line reached immediately: part has no headers.
                    headers = CaseInsensitiveDict({})
                    self._state = self._STATE_BODY
                    yield 'headers', headers
                elif self._header_size_limit:
                    # Protects against malformed input in which the
                    # headers never end: stop buffering past the limit.
                    if (self._splitter.leftover_length >
                            self._header_size_limit):
                        raise ImproperBodyPartContentException()

            elif self._state == self._STATE_BODY:
                stream, self._found = self._splitter.stream(
                    data, self._boundary_split
                )
                if stream:
                    yield 'stream', stream
                if self._found:
                    # Boundary found: this part is complete.
                    self._state = self._STATE_HEADERS
                    yield 'done', False


class StreamSplitter(object):
    """Buffer incoming bytes and cut them at a delimiter.

    Bytes that were received but not yet handed back are kept in
    :attr:`leftover` between calls.
    """

    def __init__(self):
        self.leftover = b''

    def stream(self, data, split_data, return_only_full=False):
        """Append ``data`` to the buffer and try to split at ``split_data``.

        :param data: newly received bytes.
        :param split_data: delimiter to look for in the buffer.
        :param return_only_full: when true, nothing is returned until
            the delimiter has been seen; otherwise bytes that can no
            longer be part of a delimiter are released early.
        :returns: ``(chunk, found)``. When ``found`` is true, ``chunk``
            is everything before the delimiter and the buffer keeps
            what follows it.
        """
        self.leftover += data
        marker_len = len(split_data)
        position = self.leftover.find(split_data)

        if position != -1:
            chunk = self.leftover[:position]
            self.leftover = self.leftover[position + marker_len:]
            return chunk, True

        if return_only_full or len(self.leftover) < marker_len:
            # Keep accumulating; nothing can be released yet.
            return b'', False

        # Hold back a delimiter-sized tail in case the delimiter is
        # split across two reads; everything before it is safe to emit.
        chunk = self.leftover[:-marker_len]
        self.leftover = self.leftover[-marker_len:]
        return chunk, False

    @property
    def leftover_length(self):
        """Number of bytes currently held in the buffer."""
        return len(self.leftover)
64 changes: 62 additions & 2 deletions tests/test_multipart_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from requests_toolbelt.multipart.decoder import (
ImproperBodyPartContentException
)
from requests_toolbelt.multipart.decoder import MultipartDecoder
from requests_toolbelt.multipart.decoder import (
MultipartDecoder, MultipartStreamDecoder
)
from requests_toolbelt.multipart.decoder import (
NonMultipartContentTypeException
)
Expand Down Expand Up @@ -104,26 +106,50 @@ def setUp(self):
)
self.boundary = 'test boundary'
self.encoded_1 = MultipartEncoder(self.sample_1, self.boundary)
self.encoded_1_string = self.encoded_1.to_string()
self.decoded_1 = MultipartDecoder(
self.encoded_1.to_string(),
self.encoded_1_string,
self.encoded_1.content_type
)

def CreateMultipartStreamDecoder(self):
    """Decode the encoded sample with the streaming decoder.

    The encoded body is fed to the decoder ten bytes at a time; the
    fully read parts are returned as a list.
    """
    source = io.BytesIO(self.encoded_1_string)
    stream_decoder = MultipartStreamDecoder(
        lambda: source.read(10),
        self.encoded_1.content_type
    )

    collected = []
    for body_part in stream_decoder:
        # Touching ``content`` reads the part to its end, which the
        # decoder requires before it will move on to the next part.
        body_part.content
        collected.append(body_part)
    return collected

def test_non_multipart_response_fails(self):
    """Both decoders refuse a response that is not multipart."""
    jpeg_response = mock.NonCallableMagicMock(spec=requests.Response)
    jpeg_response.headers = {'content-type': 'image/jpeg'}
    for decoder_cls in (MultipartDecoder, MultipartStreamDecoder):
        with pytest.raises(NonMultipartContentTypeException):
            decoder_cls.from_response(jpeg_response)

def test_length_of_parts(self):
    """Each decoder produces exactly one part per sample field."""
    expected = len(self.sample_1)
    assert expected == len(self.decoded_1.parts)
    streamed_parts = self.CreateMultipartStreamDecoder()
    assert expected == len(streamed_parts)

def test_content_of_parts(self):
    """Part bodies from both decoders match the encoded sample values."""
    def all_match(parts):
        # Pair every decoded part with the sample field it came from.
        return all(
            part.content == encode_with(sample[1], 'utf-8')
            for part, sample in zip(parts, self.sample_1)
        )

    assert all_match(self.decoded_1.parts)
    assert all_match(self.CreateMultipartStreamDecoder())

def test_header_of_parts(self):
def parts_header_equal(part, sample):
Expand All @@ -136,6 +162,11 @@ def parts_header_equal(part, sample):
parts_header_equal(part, sample)
for part, sample in parts_iter
)
parts_iter = zip(self.CreateMultipartStreamDecoder(), self.sample_1)
assert all(
parts_header_equal(part, sample)
for part, sample in parts_iter
)

def test_from_response(self):
response = mock.NonCallableMagicMock(spec=requests.Response)
Expand Down Expand Up @@ -163,6 +194,21 @@ def test_from_response(self):
assert len(decoder_2.parts[1].headers) == 0
assert decoder_2.parts[1].content == b'Body 2, Line 1'

cnt.seek(0)
response.raw = cnt
decoder_2 = MultipartStreamDecoder.from_response(response)
parts = []
for part in decoder_2:
part.content
parts.append(part)
assert decoder_2.content_type == response.headers['content-type']
assert (
parts[0].content == b'Body 1, Line 1\r\nBody 1, Line 2'
)
assert parts[0].headers[b'Header-1'] == b'Header-Value-1'
assert len(parts[1].headers) == 0
assert parts[1].content == b'Body 2, Line 1'

def test_from_responsecaplarge(self):
response = mock.NonCallableMagicMock(spec=requests.Response)
response.headers = {
Expand All @@ -189,3 +235,17 @@ def test_from_responsecaplarge(self):
assert len(decoder_2.parts[1].headers) == 0
assert decoder_2.parts[1].content == b'Body 2, Line 1'

cnt.seek(0)
response.raw = cnt
decoder_2 = MultipartStreamDecoder.from_response(response)
parts = []
for part in decoder_2:
part.content
parts.append(part)
assert decoder_2.content_type == response.headers['content-type']
assert (
parts[0].content == b'Body 1, Line 1\r\nBody 1, Line 2'
)
assert parts[0].headers[b'Header-1'] == b'Header-Value-1'
assert len(parts[1].headers) == 0
assert parts[1].content == b'Body 2, Line 1'