Bug 792831 - WebSocket permessage compression extension, r=jduell
This commit is contained in:
@@ -56,6 +56,11 @@ import socket
|
||||
import traceback
|
||||
import zlib
|
||||
|
||||
try:
|
||||
from mod_pywebsocket import fast_masking
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_stack_trace():
|
||||
"""Get the current stack trace as string.
|
||||
@@ -169,45 +174,39 @@ class RepeatedXorMasker(object):
|
||||
ended and resumes from that point on the next mask method call.
|
||||
"""
|
||||
|
||||
def __init__(self, mask):
|
||||
self._mask = map(ord, mask)
|
||||
self._mask_size = len(self._mask)
|
||||
self._count = 0
|
||||
def __init__(self, masking_key):
|
||||
self._masking_key = masking_key
|
||||
self._masking_key_index = 0
|
||||
|
||||
def mask(self, s):
|
||||
def _mask_using_swig(self, s):
|
||||
masked_data = fast_masking.mask(
|
||||
s, self._masking_key, self._masking_key_index)
|
||||
self._masking_key_index = (
|
||||
(self._masking_key_index + len(s)) % len(self._masking_key))
|
||||
return masked_data
|
||||
|
||||
def _mask_using_array(self, s):
|
||||
result = array.array('B')
|
||||
result.fromstring(s)
|
||||
|
||||
# Use temporary local variables to eliminate the cost to access
|
||||
# attributes
|
||||
count = self._count
|
||||
mask = self._mask
|
||||
mask_size = self._mask_size
|
||||
masking_key = map(ord, self._masking_key)
|
||||
masking_key_size = len(masking_key)
|
||||
masking_key_index = self._masking_key_index
|
||||
|
||||
for i in xrange(len(result)):
|
||||
result[i] ^= mask[count]
|
||||
count = (count + 1) % mask_size
|
||||
self._count = count
|
||||
result[i] ^= masking_key[masking_key_index]
|
||||
masking_key_index = (masking_key_index + 1) % masking_key_size
|
||||
|
||||
self._masking_key_index = masking_key_index
|
||||
|
||||
return result.tostring()
|
||||
|
||||
|
||||
class DeflateRequest(object):
|
||||
"""A wrapper class for request object to intercept send and recv to perform
|
||||
deflate compression and decompression transparently.
|
||||
"""
|
||||
|
||||
def __init__(self, request):
|
||||
self._request = request
|
||||
self.connection = DeflateConnection(request.connection)
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name in ('_request', 'connection'):
|
||||
return object.__getattribute__(self, name)
|
||||
return self._request.__getattribute__(name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in ('_request', 'connection'):
|
||||
return object.__setattr__(self, name, value)
|
||||
return self._request.__setattr__(name, value)
|
||||
if 'fast_masking' in globals():
|
||||
mask = _mask_using_swig
|
||||
else:
|
||||
mask = _mask_using_array
|
||||
|
||||
|
||||
# By making wbits option negative, we can suppress CMF/FLG (2 octet) and
|
||||
@@ -232,6 +231,12 @@ class _Deflater(object):
|
||||
self._compress = zlib.compressobj(
|
||||
zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits)
|
||||
|
||||
def compress(self, bytes):
|
||||
compressed_bytes = self._compress.compress(bytes)
|
||||
self._logger.debug('Compress input %r', bytes)
|
||||
self._logger.debug('Compress result %r', compressed_bytes)
|
||||
return compressed_bytes
|
||||
|
||||
def compress_and_flush(self, bytes):
|
||||
compressed_bytes = self._compress.compress(bytes)
|
||||
compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH)
|
||||
@@ -239,11 +244,19 @@ class _Deflater(object):
|
||||
self._logger.debug('Compress result %r', compressed_bytes)
|
||||
return compressed_bytes
|
||||
|
||||
def compress_and_finish(self, bytes):
|
||||
compressed_bytes = self._compress.compress(bytes)
|
||||
compressed_bytes += self._compress.flush(zlib.Z_FINISH)
|
||||
self._logger.debug('Compress input %r', bytes)
|
||||
self._logger.debug('Compress result %r', compressed_bytes)
|
||||
return compressed_bytes
|
||||
|
||||
|
||||
class _Inflater(object):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, window_bits):
|
||||
self._logger = get_class_logger(self)
|
||||
self._window_bits = window_bits
|
||||
|
||||
self._unconsumed = ''
|
||||
|
||||
@@ -300,7 +313,7 @@ class _Inflater(object):
|
||||
|
||||
def reset(self):
|
||||
self._logger.debug('Reset')
|
||||
self._decompress = zlib.decompressobj(-zlib.MAX_WBITS)
|
||||
self._decompress = zlib.decompressobj(-self._window_bits)
|
||||
|
||||
|
||||
# Compresses/decompresses given octets using the method introduced in RFC1979.
|
||||
@@ -318,13 +331,27 @@ class _RFC1979Deflater(object):
|
||||
self._window_bits = window_bits
|
||||
self._no_context_takeover = no_context_takeover
|
||||
|
||||
def filter(self, bytes):
|
||||
if self._deflater is None or self._no_context_takeover:
|
||||
def filter(self, bytes, end=True, bfinal=False):
|
||||
if self._deflater is None:
|
||||
self._deflater = _Deflater(self._window_bits)
|
||||
|
||||
# Strip last 4 octets which is LEN and NLEN field of a non-compressed
|
||||
# block added for Z_SYNC_FLUSH.
|
||||
return self._deflater.compress_and_flush(bytes)[:-4]
|
||||
if bfinal:
|
||||
result = self._deflater.compress_and_finish(bytes)
|
||||
# Add a padding block with BFINAL = 0 and BTYPE = 0.
|
||||
result = result + chr(0)
|
||||
self._deflater = None
|
||||
return result
|
||||
|
||||
result = self._deflater.compress_and_flush(bytes)
|
||||
if end:
|
||||
# Strip last 4 octets which is LEN and NLEN field of a
|
||||
# non-compressed block added for Z_SYNC_FLUSH.
|
||||
result = result[:-4]
|
||||
|
||||
if self._no_context_takeover and end:
|
||||
self._deflater = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class _RFC1979Inflater(object):
|
||||
@@ -332,8 +359,8 @@ class _RFC1979Inflater(object):
|
||||
the algorithm described in the RFC1979 section 2.1.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._inflater = _Inflater()
|
||||
def __init__(self, window_bits=zlib.MAX_WBITS):
|
||||
self._inflater = _Inflater(window_bits)
|
||||
|
||||
def filter(self, bytes):
|
||||
# Restore stripped LEN and NLEN field of a non-compressed block added
|
||||
@@ -356,7 +383,7 @@ class DeflateSocket(object):
|
||||
self._logger = get_class_logger(self)
|
||||
|
||||
self._deflater = _Deflater(zlib.MAX_WBITS)
|
||||
self._inflater = _Inflater()
|
||||
self._inflater = _Inflater(zlib.MAX_WBITS)
|
||||
|
||||
def recv(self, size):
|
||||
"""Receives data from the socket specified on the construction up
|
||||
@@ -386,111 +413,4 @@ class DeflateSocket(object):
|
||||
return len(bytes)
|
||||
|
||||
|
||||
class DeflateConnection(object):
|
||||
"""A wrapper class for request object to intercept write and read to
|
||||
perform deflate compression and decompression transparently.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
self._connection = connection
|
||||
|
||||
self._logger = get_class_logger(self)
|
||||
|
||||
self._deflater = _Deflater(zlib.MAX_WBITS)
|
||||
self._inflater = _Inflater()
|
||||
|
||||
def get_remote_addr(self):
|
||||
return self._connection.remote_addr
|
||||
remote_addr = property(get_remote_addr)
|
||||
|
||||
def put_bytes(self, bytes):
|
||||
self.write(bytes)
|
||||
|
||||
def read(self, size=-1):
|
||||
"""Reads at most size bytes. Blocks until there's at least one byte
|
||||
available.
|
||||
"""
|
||||
|
||||
# TODO(tyoshino): Allow call with size=0.
|
||||
if not (size == -1 or size > 0):
|
||||
raise Exception('size must be -1 or positive')
|
||||
|
||||
data = ''
|
||||
while True:
|
||||
if size == -1:
|
||||
data += self._inflater.decompress(-1)
|
||||
else:
|
||||
data += self._inflater.decompress(size - len(data))
|
||||
|
||||
if size >= 0 and len(data) != 0:
|
||||
break
|
||||
|
||||
# TODO(tyoshino): Make this read efficient by some workaround.
|
||||
#
|
||||
# In 3.0.3 and prior of mod_python, read blocks until length bytes
|
||||
# was read. We don't know the exact size to read while using
|
||||
# deflate, so read byte-by-byte.
|
||||
#
|
||||
# _StandaloneRequest.read that ultimately performs
|
||||
# socket._fileobject.read also blocks until length bytes was read
|
||||
read_data = self._connection.read(1)
|
||||
if not read_data:
|
||||
break
|
||||
self._inflater.append(read_data)
|
||||
return data
|
||||
|
||||
def write(self, bytes):
|
||||
self._connection.write(self._deflater.compress_and_flush(bytes))
|
||||
|
||||
|
||||
def _is_ewouldblock_errno(error_number):
|
||||
"""Returns True iff error_number indicates that receive operation would
|
||||
block. To make this portable, we check availability of errno and then
|
||||
compare them.
|
||||
"""
|
||||
|
||||
for error_name in ['WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN']:
|
||||
if (error_name in dir(errno) and
|
||||
error_number == getattr(errno, error_name)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def drain_received_data(raw_socket):
|
||||
# Set the socket non-blocking.
|
||||
original_timeout = raw_socket.gettimeout()
|
||||
raw_socket.settimeout(0.0)
|
||||
|
||||
drained_data = []
|
||||
|
||||
# Drain until the socket is closed or no data is immediately
|
||||
# available for read.
|
||||
while True:
|
||||
try:
|
||||
data = raw_socket.recv(1)
|
||||
if not data:
|
||||
break
|
||||
drained_data.append(data)
|
||||
except socket.error, e:
|
||||
# e can be either a pair (errno, string) or just a string (or
|
||||
# something else) telling what went wrong. We suppress only
|
||||
# the errors that indicates that the socket blocks. Those
|
||||
# exceptions can be parsed as a pair (errno, string).
|
||||
try:
|
||||
error_number, message = e
|
||||
except:
|
||||
# Failed to parse socket.error.
|
||||
raise e
|
||||
|
||||
if _is_ewouldblock_errno(error_number):
|
||||
break
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Rollback timeout value.
|
||||
raw_socket.settimeout(original_timeout)
|
||||
|
||||
return ''.join(drained_data)
|
||||
|
||||
|
||||
# vi:sts=4 sw=4 et
|
||||
|
||||
Reference in New Issue
Block a user