Bug 792831 - WebSocket permessage compression extension, r=jduell

This commit is contained in:
Michal Novotny
2014-12-29 12:43:40 +01:00
parent ff9aeb2409
commit 4c7294112e
32 changed files with 4631 additions and 1381 deletions

View File

@@ -56,6 +56,11 @@ import socket
import traceback
import zlib
try:
from mod_pywebsocket import fast_masking
except ImportError:
pass
def get_stack_trace():
"""Get the current stack trace as string.
@@ -169,45 +174,39 @@ class RepeatedXorMasker(object):
ended and resumes from that point on the next mask method call.
"""
def __init__(self, mask):
self._mask = map(ord, mask)
self._mask_size = len(self._mask)
self._count = 0
def __init__(self, masking_key):
self._masking_key = masking_key
self._masking_key_index = 0
def mask(self, s):
def _mask_using_swig(self, s):
masked_data = fast_masking.mask(
s, self._masking_key, self._masking_key_index)
self._masking_key_index = (
(self._masking_key_index + len(s)) % len(self._masking_key))
return masked_data
def _mask_using_array(self, s):
result = array.array('B')
result.fromstring(s)
# Use temporary local variables to eliminate the cost to access
# attributes
count = self._count
mask = self._mask
mask_size = self._mask_size
masking_key = map(ord, self._masking_key)
masking_key_size = len(masking_key)
masking_key_index = self._masking_key_index
for i in xrange(len(result)):
result[i] ^= mask[count]
count = (count + 1) % mask_size
self._count = count
result[i] ^= masking_key[masking_key_index]
masking_key_index = (masking_key_index + 1) % masking_key_size
self._masking_key_index = masking_key_index
return result.tostring()
class DeflateRequest(object):
"""A wrapper class for request object to intercept send and recv to perform
deflate compression and decompression transparently.
"""
def __init__(self, request):
self._request = request
self.connection = DeflateConnection(request.connection)
def __getattribute__(self, name):
if name in ('_request', 'connection'):
return object.__getattribute__(self, name)
return self._request.__getattribute__(name)
def __setattr__(self, name, value):
if name in ('_request', 'connection'):
return object.__setattr__(self, name, value)
return self._request.__setattr__(name, value)
if 'fast_masking' in globals():
mask = _mask_using_swig
else:
mask = _mask_using_array
# By making wbits option negative, we can suppress CMF/FLG (2 octet) and
@@ -232,6 +231,12 @@ class _Deflater(object):
self._compress = zlib.compressobj(
zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits)
def compress(self, bytes):
compressed_bytes = self._compress.compress(bytes)
self._logger.debug('Compress input %r', bytes)
self._logger.debug('Compress result %r', compressed_bytes)
return compressed_bytes
def compress_and_flush(self, bytes):
compressed_bytes = self._compress.compress(bytes)
compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH)
@@ -239,11 +244,19 @@ class _Deflater(object):
self._logger.debug('Compress result %r', compressed_bytes)
return compressed_bytes
def compress_and_finish(self, bytes):
compressed_bytes = self._compress.compress(bytes)
compressed_bytes += self._compress.flush(zlib.Z_FINISH)
self._logger.debug('Compress input %r', bytes)
self._logger.debug('Compress result %r', compressed_bytes)
return compressed_bytes
class _Inflater(object):
def __init__(self):
def __init__(self, window_bits):
self._logger = get_class_logger(self)
self._window_bits = window_bits
self._unconsumed = ''
@@ -300,7 +313,7 @@ class _Inflater(object):
def reset(self):
self._logger.debug('Reset')
self._decompress = zlib.decompressobj(-zlib.MAX_WBITS)
self._decompress = zlib.decompressobj(-self._window_bits)
# Compresses/decompresses given octets using the method introduced in RFC1979.
@@ -318,13 +331,27 @@ class _RFC1979Deflater(object):
self._window_bits = window_bits
self._no_context_takeover = no_context_takeover
def filter(self, bytes):
if self._deflater is None or self._no_context_takeover:
def filter(self, bytes, end=True, bfinal=False):
if self._deflater is None:
self._deflater = _Deflater(self._window_bits)
# Strip last 4 octets which is LEN and NLEN field of a non-compressed
# block added for Z_SYNC_FLUSH.
return self._deflater.compress_and_flush(bytes)[:-4]
if bfinal:
result = self._deflater.compress_and_finish(bytes)
# Add a padding block with BFINAL = 0 and BTYPE = 0.
result = result + chr(0)
self._deflater = None
return result
result = self._deflater.compress_and_flush(bytes)
if end:
# Strip last 4 octets which is LEN and NLEN field of a
# non-compressed block added for Z_SYNC_FLUSH.
result = result[:-4]
if self._no_context_takeover and end:
self._deflater = None
return result
class _RFC1979Inflater(object):
@@ -332,8 +359,8 @@ class _RFC1979Inflater(object):
the algorithm described in the RFC1979 section 2.1.
"""
def __init__(self):
self._inflater = _Inflater()
def __init__(self, window_bits=zlib.MAX_WBITS):
self._inflater = _Inflater(window_bits)
def filter(self, bytes):
# Restore stripped LEN and NLEN field of a non-compressed block added
@@ -356,7 +383,7 @@ class DeflateSocket(object):
self._logger = get_class_logger(self)
self._deflater = _Deflater(zlib.MAX_WBITS)
self._inflater = _Inflater()
self._inflater = _Inflater(zlib.MAX_WBITS)
def recv(self, size):
"""Receives data from the socket specified on the construction up
@@ -386,111 +413,4 @@ class DeflateSocket(object):
return len(bytes)
class DeflateConnection(object):
"""A wrapper class for request object to intercept write and read to
perform deflate compression and decompression transparently.
"""
def __init__(self, connection):
self._connection = connection
self._logger = get_class_logger(self)
self._deflater = _Deflater(zlib.MAX_WBITS)
self._inflater = _Inflater()
def get_remote_addr(self):
return self._connection.remote_addr
remote_addr = property(get_remote_addr)
def put_bytes(self, bytes):
self.write(bytes)
def read(self, size=-1):
"""Reads at most size bytes. Blocks until there's at least one byte
available.
"""
# TODO(tyoshino): Allow call with size=0.
if not (size == -1 or size > 0):
raise Exception('size must be -1 or positive')
data = ''
while True:
if size == -1:
data += self._inflater.decompress(-1)
else:
data += self._inflater.decompress(size - len(data))
if size >= 0 and len(data) != 0:
break
# TODO(tyoshino): Make this read efficient by some workaround.
#
# In 3.0.3 and prior of mod_python, read blocks until length bytes
# was read. We don't know the exact size to read while using
# deflate, so read byte-by-byte.
#
# _StandaloneRequest.read that ultimately performs
# socket._fileobject.read also blocks until length bytes was read
read_data = self._connection.read(1)
if not read_data:
break
self._inflater.append(read_data)
return data
def write(self, bytes):
self._connection.write(self._deflater.compress_and_flush(bytes))
def _is_ewouldblock_errno(error_number):
"""Returns True iff error_number indicates that receive operation would
block. To make this portable, we check availability of errno and then
compare them.
"""
for error_name in ['WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN']:
if (error_name in dir(errno) and
error_number == getattr(errno, error_name)):
return True
return False
def drain_received_data(raw_socket):
# Set the socket non-blocking.
original_timeout = raw_socket.gettimeout()
raw_socket.settimeout(0.0)
drained_data = []
# Drain until the socket is closed or no data is immediately
# available for read.
while True:
try:
data = raw_socket.recv(1)
if not data:
break
drained_data.append(data)
except socket.error, e:
# e can be either a pair (errno, string) or just a string (or
# something else) telling what went wrong. We suppress only
# the errors that indicates that the socket blocks. Those
# exceptions can be parsed as a pair (errno, string).
try:
error_number, message = e
except:
# Failed to parse socket.error.
raise e
if _is_ewouldblock_errno(error_number):
break
else:
raise e
# Rollback timeout value.
raw_socket.settimeout(original_timeout)
return ''.join(drained_data)
# vi:sts=4 sw=4 et