gorealis v2 refactor (#5)

* Changing default timeout for start maintenance.

* Upgrading dependencies to gorealis v2 and thrift  0.12.0

* Refactored to update to gorealis v2.
This commit is contained in:
Renan DelValle 2018-12-27 11:31:51 -08:00 committed by GitHub
parent ad4dd9606e
commit 6ab5c9334d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
1335 changed files with 137431 additions and 61530 deletions

1
vendor/git.apache.org/thrift.git/lib/py/MANIFEST.in generated vendored Normal file
View file

@ -0,0 +1 @@
include src/ext/*

View file

@ -49,6 +49,7 @@ check-local: all py3-test
EXTRA_DIST = \
CMakeLists.txt \
MANIFEST.in \
coding_standards.md \
compat \
setup.py \

View file

@ -22,7 +22,7 @@
import sys
try:
from setuptools import setup, Extension
except:
except Exception:
from distutils.core import setup, Extension
from distutils.command.build_ext import build_ext
@ -31,7 +31,10 @@ from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatfo
# Fix to build sdist under vagrant
import os
if 'vagrant' in str(os.environ):
del os.link
try:
del os.link
except AttributeError:
pass
include_dirs = ['src']
if sys.platform == 'win32':
@ -87,7 +90,7 @@ def run_setup(with_binary):
twisted_deps = ['twisted']
setup(name='thrift',
version='0.10.0',
version='0.12.0',
description='Python bindings for the Apache Thrift RPC system',
author='Thrift Developers',
author_email='dev@thrift.apache.org',
@ -117,9 +120,11 @@ def run_setup(with_binary):
'Topic :: Software Development :: Libraries',
'Topic :: System :: Networking'
],
zip_safe=False,
**extensions
)
try:
with_binary = True
run_setup(with_binary)

View file

@ -31,11 +31,11 @@ class TMultiplexedProcessor(TProcessor):
def process(self, iprot, oprot):
(name, type, seqid) = iprot.readMessageBegin()
if type != TMessageType.CALL and type != TMessageType.ONEWAY:
raise TException("TMultiplex protocol only supports CALL & ONEWAY")
raise TException("TMultiplexed protocol only supports CALL & ONEWAY")
index = name.find(TMultiplexedProtocol.SEPARATOR)
if index < 0:
raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexedProtocol in your client?")
serviceName = name[0:index]
call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
@ -48,7 +48,6 @@ class TMultiplexedProcessor(TProcessor):
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, messageBegin):
TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
self.messageBegin = messageBegin
def readMessageBegin(self):

View file

@ -0,0 +1,83 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from thrift.Thrift import TType
TYPE_IDX = 1
SPEC_ARGS_IDX = 3
SPEC_ARGS_CLASS_REF_IDX = 0
SPEC_ARGS_THRIFT_SPEC_IDX = 1
def fix_spec(all_structs):
"""Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
for struc in all_structs:
spec = struc.thrift_spec
for thrift_spec in spec:
if thrift_spec is None:
continue
elif thrift_spec[TYPE_IDX] == TType.STRUCT:
other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec
thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other
elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET):
_fix_list_or_set(thrift_spec[SPEC_ARGS_IDX])
elif thrift_spec[TYPE_IDX] == TType.MAP:
_fix_map(thrift_spec[SPEC_ARGS_IDX])
def _fix_list_or_set(element_type):
# For a list or set, the thrift_spec entry looks like,
# (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1
# so ``element_type`` will be,
# (TType.STRUCT, [RecList, None], False)
if element_type[0] == TType.STRUCT:
element_type[1][1] = element_type[1][0].thrift_spec
elif element_type[0] in (TType.LIST, TType.SET):
_fix_list_or_set(element_type[1])
elif element_type[0] == TType.MAP:
_fix_map(element_type[1])
def _fix_map(element_type):
# For a map of key -> value type, ``element_type`` will be,
# (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, )
# which is just a normal struct definition.
#
# For a map of key -> list / set, ``element_type`` will be,
# (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False)
# and we need to process the 3rd element as a list.
#
# For a map of key -> map, ``element_type`` will be,
# (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT,
# [RecMapMap, None], False), False)
# and need to process 3rd element as a map.
# Is the map key a struct?
if element_type[0] == TType.STRUCT:
element_type[1][1] = element_type[1][0].thrift_spec
elif element_type[0] in (TType.LIST, TType.SET):
_fix_list_or_set(element_type[1])
elif element_type[0] == TType.MAP:
_fix_map(element_type[1])
# Is the map value a struct?
if element_type[2] == TType.STRUCT:
element_type[3][1] = element_type[3][0].thrift_spec
elif element_type[2] in (TType.LIST, TType.SET):
_fix_list_or_set(element_type[3])
elif element_type[2] == TType.MAP:
_fix_map(element_type[3])

View file

@ -69,9 +69,9 @@ class TMessageType(object):
class TProcessor(object):
"""Base class for procsessor, which works on two streams."""
"""Base class for processor, which works on two streams."""
def process(iprot, oprot):
def process(self, iprot, oprot):
pass

View file

@ -29,6 +29,9 @@ if sys.version_info[0] == 2:
def str_to_binary(str_val):
return str_val
def byte_index(bytes_val, i):
return ord(bytes_val[i])
else:
from io import BytesIO as BufferIO # noqa
@ -38,3 +41,6 @@ else:
def str_to_binary(str_val):
return bytes(str_val, 'utf8')
def byte_index(bytes_val, i):
return bytes_val[i]

View file

@ -113,7 +113,8 @@ public:
if (!readBytes(&buf, sizeof(int16_t))) {
return false;
}
val = static_cast<int16_t>(ntohs(*reinterpret_cast<int16_t*>(buf)));
memcpy(&val, buf, sizeof(int16_t));
val = ntohs(val);
return true;
}
@ -122,7 +123,8 @@ public:
if (!readBytes(&buf, sizeof(int32_t))) {
return false;
}
val = static_cast<int32_t>(ntohl(*reinterpret_cast<int32_t*>(buf)));
memcpy(&val, buf, sizeof(int32_t));
val = ntohl(val);
return true;
}
@ -131,7 +133,8 @@ public:
if (!readBytes(&buf, sizeof(int64_t))) {
return false;
}
val = static_cast<int64_t>(ntohll(*reinterpret_cast<int64_t*>(buf)));
memcpy(&val, buf, sizeof(int64_t));
val = ntohll(val);
return true;
}

View file

@ -162,7 +162,8 @@ public:
if (!readBytes(&buf, 8)) {
return false;
}
transfer.f = letohll(*reinterpret_cast<int64_t*>(buf));
memcpy(&transfer.f, buf, sizeof(int64_t));
transfer.f = letohll(transfer.f);
val = transfer.t;
return true;
}

View file

@ -87,12 +87,7 @@ static PyObject* decode_impl(PyObject* args) {
}
T protocol;
#ifdef _MSC_VER
// workaround strange VC++ 2015 bug where #else path does not compile
int32_t default_limit = INT32_MAX;
#else
int32_t default_limit = std::numeric_limits<int32_t>::max();
#endif
int32_t default_limit = (std::numeric_limits<int32_t>::max)();
protocol.setStringLengthLimit(
as_long_then_delete(PyObject_GetAttr(oprot, INTERN_STRING(string_length_limit)),
default_limit));

View file

@ -33,8 +33,8 @@ class ProtocolBase {
public:
ProtocolBase()
: stringLimit_(std::numeric_limits<int32_t>::max()),
containerLimit_(std::numeric_limits<int32_t>::max()),
: stringLimit_((std::numeric_limits<int32_t>::max)()),
containerLimit_((std::numeric_limits<int32_t>::max)()),
output_(NULL) {}
inline virtual ~ProtocolBase();

View file

@ -102,7 +102,7 @@ inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
PyErr_SetString(PyExc_IOError, "failed to write to cStringIO object");
return false;
}
if (len != size) {
if (static_cast<size_t>(len) != size) {
PyErr_Format(PyExc_EOFError, "write length mismatch: expected %lu got %d", size, len);
return false;
}
@ -144,7 +144,7 @@ inline int read_buffer(PyObject* buf, char** output, int len) {
*output = PyBytes_AS_STRING(buf2->buf) + buf2->pos;
#endif
Py_ssize_t pos0 = buf2->pos;
buf2->pos = std::min(buf2->pos + static_cast<Py_ssize_t>(len), buf2->string_size);
buf2->pos = (std::min)(buf2->pos + static_cast<Py_ssize_t>(len), buf2->string_size);
return static_cast<int>(buf2->pos - pos0);
}
}
@ -212,7 +212,7 @@ inline bool check_ssize_t_32(Py_ssize_t len) {
if (INT_CONV_ERROR_OCCURRED(len)) {
return false;
}
if (!CHECK_RANGE(len, 0, std::numeric_limits<int32_t>::max())) {
if (!CHECK_RANGE(len, 0, (std::numeric_limits<int32_t>::max)())) {
PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX");
return false;
}
@ -360,8 +360,8 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type
case T_I08: {
int8_t val;
if (!parse_pyint(value, &val, std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max())) {
if (!parse_pyint(value, &val, (std::numeric_limits<int8_t>::min)(),
(std::numeric_limits<int8_t>::max)())) {
return false;
}
@ -371,8 +371,8 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type
case T_I16: {
int16_t val;
if (!parse_pyint(value, &val, std::numeric_limits<int16_t>::min(),
std::numeric_limits<int16_t>::max())) {
if (!parse_pyint(value, &val, (std::numeric_limits<int16_t>::min)(),
(std::numeric_limits<int16_t>::max)())) {
return false;
}
@ -382,8 +382,8 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type
case T_I32: {
int32_t val;
if (!parse_pyint(value, &val, std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::max())) {
if (!parse_pyint(value, &val, (std::numeric_limits<int32_t>::min)(),
(std::numeric_limits<int32_t>::max)())) {
return false;
}
@ -397,8 +397,8 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type
return false;
}
if (!CHECK_RANGE(nval, std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max())) {
if (!CHECK_RANGE(nval, (std::numeric_limits<int64_t>::min)(),
(std::numeric_limits<int64_t>::max)())) {
PyErr_SetString(PyExc_OverflowError, "int out of range");
return false;
}

View file

@ -98,13 +98,13 @@ bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
}
bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) {
if (PyTuple_Size(typeargs) != 2) {
PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args");
if (PyList_Size(typeargs) != 2) {
PyErr_SetString(PyExc_TypeError, "expecting list of size 2 for struct args");
return false;
}
dest->klass = PyTuple_GET_ITEM(typeargs, 0);
dest->spec = PyTuple_GET_ITEM(typeargs, 1);
dest->klass = PyList_GET_ITEM(typeargs, 0);
dest->spec = PyList_GET_ITEM(typeargs, 1);
return true;
}

View file

@ -23,6 +23,7 @@
#include <Python.h>
#ifdef _MSC_VER
#define __STDC_FORMAT_MACROS
#define __STDC_LIMIT_MACROS
#endif
#include <stdint.h>

View file

@ -44,14 +44,14 @@ class TBase(object):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
self.thrift_spec is not None):
iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))
iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
else:
iprot.readStruct(self, self.thrift_spec)
def write(self, oprot):
if (oprot._fast_encode is not None and self.thrift_spec is not None):
oprot.trans.write(
oprot._fast_encode(self, (self.__class__, self.thrift_spec)))
oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
else:
oprot.writeStruct(self, self.thrift_spec)
@ -77,6 +77,6 @@ class TFrozenBase(TBase):
cls.thrift_spec is not None):
self = cls()
return iprot._fast_decode(None, iprot,
(self.__class__, self.thrift_spec))
[self.__class__, self.thrift_spec])
else:
return iprot.readStruct(cls, cls.thrift_spec, True)

View file

@ -42,6 +42,8 @@ def make_helper(v_from, container):
return func(self, *args, **kwargs)
return nested
return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
@ -94,6 +96,7 @@ class CompactType(object):
MAP = 0x0B
STRUCT = 0x0C
CTYPES = {
TType.STOP: CompactType.STOP,
TType.BOOL: CompactType.TRUE, # used for collection

View file

@ -0,0 +1,225 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
from thrift.protocol.TProtocol import TProtocolBase, TProtocolException
from thrift.Thrift import TApplicationException, TMessageType
from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
PROTOCOLS_BY_ID = {
THeaderSubprotocolID.BINARY: TBinaryProtocolAccelerated,
THeaderSubprotocolID.COMPACT: TCompactProtocolAccelerated,
}
class THeaderProtocol(TProtocolBase):
"""A framed protocol with headers and payload transforms.
THeaderProtocol frames other Thrift protocols and adds support for optional
out-of-band headers. The currently supported subprotocols are
TBinaryProtocol and TCompactProtocol.
It's also possible to apply transforms to the encoded message payload. The
only transform currently supported is to gzip.
When used in a server, THeaderProtocol can accept messages from
non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
includes framed and unframed transports and both TBinaryProtocol and
TCompactProtocol. The server will respond in the appropriate dialect for
the connected client. HTTP clients are not currently supported.
THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
or TProcessPoolServer.
See doc/specs/HeaderFormat.md for details of the wire format.
"""
def __init__(self, transport, allowed_client_types):
# much of the actual work for THeaderProtocol happens down in
# THeaderTransport since we need to do low-level shenanigans to detect
# if the client is sending us headers or one of the headerless formats
# we support. this wraps the real transport with the one that does all
# the magic.
if not isinstance(transport, THeaderTransport):
transport = THeaderTransport(transport, allowed_client_types)
super(THeaderProtocol, self).__init__(transport)
self._set_protocol()
def get_headers(self):
return self.trans.get_headers()
def set_header(self, key, value):
self.trans.set_header(key, value)
def clear_headers(self):
self.trans.clear_headers()
def add_transform(self, transform_id):
self.trans.add_transform(transform_id)
def writeMessageBegin(self, name, ttype, seqid):
self.trans.sequence_id = seqid
return self._protocol.writeMessageBegin(name, ttype, seqid)
def writeMessageEnd(self):
return self._protocol.writeMessageEnd()
def writeStructBegin(self, name):
return self._protocol.writeStructBegin(name)
def writeStructEnd(self):
return self._protocol.writeStructEnd()
def writeFieldBegin(self, name, ttype, fid):
return self._protocol.writeFieldBegin(name, ttype, fid)
def writeFieldEnd(self):
return self._protocol.writeFieldEnd()
def writeFieldStop(self):
return self._protocol.writeFieldStop()
def writeMapBegin(self, ktype, vtype, size):
return self._protocol.writeMapBegin(ktype, vtype, size)
def writeMapEnd(self):
return self._protocol.writeMapEnd()
def writeListBegin(self, etype, size):
return self._protocol.writeListBegin(etype, size)
def writeListEnd(self):
return self._protocol.writeListEnd()
def writeSetBegin(self, etype, size):
return self._protocol.writeSetBegin(etype, size)
def writeSetEnd(self):
return self._protocol.writeSetEnd()
def writeBool(self, bool_val):
return self._protocol.writeBool(bool_val)
def writeByte(self, byte):
return self._protocol.writeByte(byte)
def writeI16(self, i16):
return self._protocol.writeI16(i16)
def writeI32(self, i32):
return self._protocol.writeI32(i32)
def writeI64(self, i64):
return self._protocol.writeI64(i64)
def writeDouble(self, dub):
return self._protocol.writeDouble(dub)
def writeBinary(self, str_val):
return self._protocol.writeBinary(str_val)
def _set_protocol(self):
try:
protocol_cls = PROTOCOLS_BY_ID[self.trans.protocol_id]
except KeyError:
raise TApplicationException(
TProtocolException.INVALID_PROTOCOL,
"Unknown protocol requested.",
)
self._protocol = protocol_cls(self.trans)
self._fast_encode = self._protocol._fast_encode
self._fast_decode = self._protocol._fast_decode
def readMessageBegin(self):
try:
self.trans.readFrame(0)
self._set_protocol()
except TApplicationException as exc:
self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
exc.write(self._protocol)
self._protocol.writeMessageEnd()
self.trans.flush()
return self._protocol.readMessageBegin()
def readMessageEnd(self):
return self._protocol.readMessageEnd()
def readStructBegin(self):
return self._protocol.readStructBegin()
def readStructEnd(self):
return self._protocol.readStructEnd()
def readFieldBegin(self):
return self._protocol.readFieldBegin()
def readFieldEnd(self):
return self._protocol.readFieldEnd()
def readMapBegin(self):
return self._protocol.readMapBegin()
def readMapEnd(self):
return self._protocol.readMapEnd()
def readListBegin(self):
return self._protocol.readListBegin()
def readListEnd(self):
return self._protocol.readListEnd()
def readSetBegin(self):
return self._protocol.readSetBegin()
def readSetEnd(self):
return self._protocol.readSetEnd()
def readBool(self):
return self._protocol.readBool()
def readByte(self):
return self._protocol.readByte()
def readI16(self):
return self._protocol.readI16()
def readI32(self):
return self._protocol.readI32()
def readI64(self):
return self._protocol.readI64()
def readDouble(self):
return self._protocol.readDouble()
def readBinary(self):
return self._protocol.readBinary()
class THeaderProtocolFactory(object):
def __init__(self, allowed_client_types=(THeaderClientType.HEADERS,)):
self.allowed_client_types = allowed_client_types
def getProtocol(self, trans):
return THeaderProtocol(trans, self.allowed_client_types)

View file

@ -25,16 +25,15 @@ SEPARATOR = ":"
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, serviceName):
TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
self.serviceName = serviceName
def writeMessageBegin(self, name, type, seqid):
if (type == TMessageType.CALL or
type == TMessageType.ONEWAY):
self.protocol.writeMessageBegin(
super(TMultiplexedProtocol, self).writeMessageBegin(
self.serviceName + SEPARATOR + name,
type,
seqid
)
else:
self.protocol.writeMessageBegin(name, type, seqid)
super(TMultiplexedProtocol, self).writeMessageBegin(name, type, seqid)

View file

@ -37,6 +37,7 @@ class TProtocolException(TException):
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
INVALID_PROTOCOL = 7
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
@ -268,7 +269,7 @@ class TProtocolBase(object):
return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
def _read_by_ttype(self, ttype, spec, espec):
reader_name, _, is_container = self._ttype_handlers(ttype, spec)
reader_name, _, is_container = self._ttype_handlers(ttype, espec)
if reader_name is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid type %d' % (ttype))
@ -389,7 +390,7 @@ class TProtocolBase(object):
self.writeStructEnd()
def _write_by_ttype(self, ttype, vals, spec, espec):
_, writer_name, is_container = self._ttype_handlers(ttype, spec)
_, writer_name, is_container = self._ttype_handlers(ttype, espec)
writer_func = getattr(self, writer_name)
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
for v in vals:

View file

@ -17,34 +17,10 @@
# under the License.
#
import types
from thrift.protocol.TProtocol import TProtocolBase
class TProtocolDecorator():
def __init__(self, protocol):
TProtocolBase(protocol)
self.protocol = protocol
def __getattr__(self, name):
if hasattr(self.protocol, name):
member = getattr(self.protocol, name)
if type(member) in [
types.MethodType,
types.FunctionType,
types.LambdaType,
types.BuiltinFunctionType,
types.BuiltinMethodType,
]:
return lambda *args, **kwargs: self._wrap(member, args, kwargs)
else:
return member
raise AttributeError(name)
def _wrap(self, func, args, kwargs):
if isinstance(func, types.MethodType):
result = func(*args, **kwargs)
else:
result = func(self.protocol, *args, **kwargs)
return result
class TProtocolDecorator(object):
def __new__(cls, protocol, *args, **kwargs):
decorated_cls = type(''.join(['Decorated', protocol.__class__.__name__]),
(cls, protocol.__class__),
protocol.__dict__)
return object.__new__(decorated_cls)

View file

@ -17,6 +17,8 @@
# under the License.
#
import ssl
from six.moves import BaseHTTPServer
from thrift.server import TServer
@ -47,11 +49,17 @@ class THttpServer(TServer.TServer):
server_address,
inputProtocolFactory,
outputProtocolFactory=None,
server_class=BaseHTTPServer.HTTPServer):
"""Set up protocol factories and HTTP server.
server_class=BaseHTTPServer.HTTPServer,
**kwargs):
"""Set up protocol factories and HTTP (or HTTPS) server.
See BaseHTTPServer for server_address.
See TServer for protocol factories.
To make a secure server, provide the named arguments:
* cafile - to validate clients [optional]
* cert_file - the server cert
* key_file - the server's key
"""
if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory
@ -83,5 +91,16 @@ class THttpServer(TServer.TServer):
self.httpd = server_class(server_address, RequestHander)
if (kwargs.get('cafile') or kwargs.get('cert_file') or kwargs.get('key_file')):
context = ssl.create_default_context(cafile=kwargs.get('cafile'))
context.check_hostname = False
context.load_cert_chain(kwargs.get('cert_file'), kwargs.get('key_file'))
context.verify_mode = ssl.CERT_REQUIRED if kwargs.get('cafile') else ssl.CERT_NONE
self.httpd.socket = context.wrap_socket(self.httpd.socket, server_side=True)
def serve(self):
self.httpd.serve_forever()
def shutdown(self):
self.httpd.socket.close()
# self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly!

View file

@ -31,6 +31,7 @@ import socket
import struct
import threading
from collections import deque
from six.moves import queue
from thrift.transport import TTransport
@ -58,9 +59,10 @@ class Worker(threading.Thread):
processor.process(iprot, oprot)
callback(True, otrans.getvalue())
except Exception:
logger.exception("Exception while processing request")
logger.exception("Exception while processing request", exc_info=True)
callback(False, b'')
WAIT_LEN = 0
WAIT_MESSAGE = 1
WAIT_PROCESS = 2
@ -85,10 +87,23 @@ def socket_exception(func):
try:
return func(self, *args, **kwargs)
except socket.error:
logger.debug('ignoring socket exception', exc_info=True)
self.close()
return read
class Message(object):
def __init__(self, offset, len_, header):
self.offset = offset
self.len = len_
self.buffer = None
self.is_header = header
@property
def end(self):
return self.offset + self.len
class Connection(object):
"""Basic class is represented connection.
@ -106,68 +121,60 @@ class Connection(object):
self.socket.setblocking(False)
self.status = WAIT_LEN
self.len = 0
self.message = b''
self.received = deque()
self._reading = Message(0, 4, True)
self._rbuf = b''
self._wbuf = b''
self.lock = threading.Lock()
self.wake_up = wake_up
def _read_len(self):
"""Reads length of request.
It's a safer alternative to self.socket.recv(4)
"""
read = self.socket.recv(4 - len(self.message))
if len(read) == 0:
# if we read 0 bytes and self.message is empty, then
# the client closed the connection
if len(self.message) != 0:
logger.error("can't read frame size from socket")
self.close()
return
self.message += read
if len(self.message) == 4:
self.len, = struct.unpack('!i', self.message)
if self.len < 0:
logger.error("negative frame size, it seems client "
"doesn't use FramedTransport")
self.close()
elif self.len == 0:
logger.error("empty frame, it's really strange")
self.close()
else:
self.message = b''
self.status = WAIT_MESSAGE
self.remaining = False
@socket_exception
def read(self):
"""Reads data from stream and switch state."""
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
if self.status == WAIT_LEN:
self._read_len()
# go back to the main loop here for simplicity instead of
# falling through, even though there is a good chance that
# the message is already available
elif self.status == WAIT_MESSAGE:
read = self.socket.recv(self.len - len(self.message))
if len(read) == 0:
logger.error("can't read frame from socket (get %d of "
"%d bytes)" % (len(self.message), self.len))
assert not self.received
buf_size = 8192
first = True
done = False
while not done:
read = self.socket.recv(buf_size)
rlen = len(read)
done = rlen < buf_size
self._rbuf += read
if first and rlen == 0:
if self.status != WAIT_LEN or self._rbuf:
logger.error('could not read frame from socket')
else:
logger.debug('read zero length. client might have disconnected')
self.close()
return
self.message += read
if len(self.message) == self.len:
while len(self._rbuf) >= self._reading.end:
if self._reading.is_header:
mlen, = struct.unpack('!i', self._rbuf[:4])
self._reading = Message(self._reading.end, mlen, False)
self.status = WAIT_MESSAGE
else:
self._reading.buffer = self._rbuf
self.received.append(self._reading)
self._rbuf = self._rbuf[self._reading.end:]
self._reading = Message(0, 4, True)
first = False
if self.received:
self.status = WAIT_PROCESS
break
self.remaining = not done
@socket_exception
def write(self):
"""Writes data from socket and switch state."""
assert self.status == SEND_ANSWER
sent = self.socket.send(self.message)
if sent == len(self.message):
sent = self.socket.send(self._wbuf)
if sent == len(self._wbuf):
self.status = WAIT_LEN
self.message = b''
self._wbuf = b''
self.len = 0
else:
self.message = self.message[sent:]
self._wbuf = self._wbuf[sent:]
@locked
def ready(self, all_ok, message):
@ -190,10 +197,10 @@ class Connection(object):
self.len = 0
if len(message) == 0:
# it was a oneway request, do not write answer
self.message = b''
self._wbuf = b''
self.status = WAIT_LEN
else:
self.message = struct.pack('!i', len(message)) + message
self._wbuf = struct.pack('!i', len(message)) + message
self.status = SEND_ANSWER
self.wake_up()
@ -292,14 +299,20 @@ class TNonblockingServer(object):
"""Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()]
writable = []
remaining = []
for i, connection in list(self.clients.items()):
if connection.is_readable():
readable.append(connection.fileno())
if connection.remaining or connection.received:
remaining.append(connection.fileno())
if connection.is_writeable():
writable.append(connection.fileno())
if connection.is_closed():
del self.clients[i]
return select.select(readable, writable, readable)
if remaining:
return remaining, [], [], False
else:
return select.select(readable, writable, readable) + (True,)
def handle(self):
"""Handle requests.
@ -307,20 +320,27 @@ class TNonblockingServer(object):
WARNING! You must call prepare() BEFORE calling handle()
"""
assert self.prepared, "You have to call prepare before handle"
rset, wset, xset = self._select()
rset, wset, xset, selected = self._select()
for readable in rset:
if readable == self._read.fileno():
# don't care i just need to clean readable flag
self._read.recv(1024)
elif readable == self.socket.handle.fileno():
client = self.socket.accept().handle
self.clients[client.fileno()] = Connection(client,
self.wake_up)
try:
client = self.socket.accept()
if client:
self.clients[client.handle.fileno()] = Connection(client.handle,
self.wake_up)
except socket.error:
logger.debug('error while accepting', exc_info=True)
else:
connection = self.clients[readable]
connection.read()
if connection.status == WAIT_PROCESS:
itransport = TTransport.TMemoryBuffer(connection.message)
if selected:
connection.read()
if connection.received:
connection.status = WAIT_PROCESS
msg = connection.received.popleft()
itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset)
otransport = TTransport.TMemoryBuffer()
iprot = self.in_protocol.getProtocol(itransport)
oprot = self.out_protocol.getProtocol(otransport)

View file

@ -23,6 +23,7 @@ import os
import threading
from thrift.protocol import TBinaryProtocol
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport import TTransport
logger = logging.getLogger(__name__)
@ -60,6 +61,12 @@ class TServer(object):
self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory
input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
if any((input_is_header, output_is_header)) and input_is_header != output_is_header:
raise ValueError("THeaderProtocol servers require that both the input and "
"output protocols are THeaderProtocol.")
def serve(self):
pass
@ -76,10 +83,20 @@ class TSimpleServer(TServer):
client = self.serverTransport.accept()
if not client:
continue
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
# for THeaderProtocol, we must use the same protocol instance for
# input and output so that the response is in the same dialect that
# the server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
@ -89,7 +106,8 @@ class TSimpleServer(TServer):
logger.exception(x)
itrans.close()
otrans.close()
if otrans:
otrans.close()
class TThreadedServer(TServer):
@ -116,9 +134,18 @@ class TThreadedServer(TServer):
def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
# for THeaderProtocol, we must use the same protocol instance for input
# and output so that the response is in the same dialect that the
# server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
@ -128,7 +155,8 @@ class TThreadedServer(TServer):
logger.exception(x)
itrans.close()
otrans.close()
if otrans:
otrans.close()
class TThreadPoolServer(TServer):
@ -156,9 +184,18 @@ class TThreadPoolServer(TServer):
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
# for THeaderProtocol, we must use the same protocol instance for input
# and output so that the response is in the same dialect that the
# server detected the request was in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
@ -168,7 +205,8 @@ class TThreadPoolServer(TServer):
logger.exception(x)
itrans.close()
otrans.close()
if otrans:
otrans.close()
def serve(self):
"""Start a fixed number of worker threads and put client into a queue"""
@ -237,10 +275,18 @@ class TForkingServer(TServer):
try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
# for THeaderProtocol, we must use the same protocol
# instance for input and output so that the response is in
# the same dialect that the server detected the request was
# in.
if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
otrans = None
oprot = iprot
else:
otrans = self.outputTransportFactory.getTransport(client)
oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0
try:
@ -254,7 +300,8 @@ class TForkingServer(TServer):
ecode = 1
finally:
try_close(itrans)
try_close(otrans)
if otrans:
try_close(otrans)
os._exit(ecode)

View file

@ -0,0 +1,352 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import struct
import zlib
from thrift.compat import BufferIO, byte_index
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.protocol.TCompactProtocol import TCompactProtocol, readVarint, writeVarint
from thrift.Thrift import TApplicationException
from thrift.transport.TTransport import (
CReadableTransport,
TMemoryBuffer,
TTransportBase,
TTransportException,
)
U16 = struct.Struct("!H")
I32 = struct.Struct("!i")
HEADER_MAGIC = 0x0FFF
HARD_MAX_FRAME_SIZE = 0x3FFFFFFF
class THeaderClientType(object):
HEADERS = 0x00
FRAMED_BINARY = 0x01
UNFRAMED_BINARY = 0x02
FRAMED_COMPACT = 0x03
UNFRAMED_COMPACT = 0x04
class THeaderSubprotocolID(object):
BINARY = 0x00
COMPACT = 0x02
class TInfoHeaderType(object):
KEY_VALUE = 0x01
class THeaderTransformID(object):
ZLIB = 0x01
READ_TRANSFORMS_BY_ID = {
THeaderTransformID.ZLIB: zlib.decompress,
}
WRITE_TRANSFORMS_BY_ID = {
THeaderTransformID.ZLIB: zlib.compress,
}
def _readString(trans):
size = readVarint(trans)
if size < 0:
raise TTransportException(
TTransportException.NEGATIVE_SIZE,
"Negative length"
)
return trans.read(size)
def _writeString(trans, value):
writeVarint(trans, len(value))
trans.write(value)
class THeaderTransport(TTransportBase, CReadableTransport):
def __init__(self, transport, allowed_client_types):
self._transport = transport
self._client_type = THeaderClientType.HEADERS
self._allowed_client_types = allowed_client_types
self._read_buffer = BufferIO(b"")
self._read_headers = {}
self._write_buffer = BufferIO()
self._write_headers = {}
self._write_transforms = []
self.flags = 0
self.sequence_id = 0
self._protocol_id = THeaderSubprotocolID.BINARY
self._max_frame_size = HARD_MAX_FRAME_SIZE
def isOpen(self):
return self._transport.isOpen()
def open(self):
return self._transport.open()
def close(self):
return self._transport.close()
def get_headers(self):
return self._read_headers
def set_header(self, key, value):
if not isinstance(key, bytes):
raise ValueError("header names must be bytes")
if not isinstance(value, bytes):
raise ValueError("header values must be bytes")
self._write_headers[key] = value
def clear_headers(self):
self._write_headers.clear()
def add_transform(self, transform_id):
if transform_id not in WRITE_TRANSFORMS_BY_ID:
raise ValueError("unknown transform")
self._write_transforms.append(transform_id)
def set_max_frame_size(self, size):
if not 0 < size < HARD_MAX_FRAME_SIZE:
raise ValueError("maximum frame size should be < %d and > 0" % HARD_MAX_FRAME_SIZE)
self._max_frame_size = size
@property
def protocol_id(self):
if self._client_type == THeaderClientType.HEADERS:
return self._protocol_id
elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.UNFRAMED_BINARY):
return THeaderSubprotocolID.BINARY
elif self._client_type in (THeaderClientType.FRAMED_COMPACT, THeaderClientType.UNFRAMED_COMPACT):
return THeaderSubprotocolID.COMPACT
else:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Protocol ID not know for client type %d" % self._client_type,
)
def read(self, sz):
# if there are bytes left in the buffer, produce those first.
bytes_read = self._read_buffer.read(sz)
bytes_left_to_read = sz - len(bytes_read)
if bytes_left_to_read == 0:
return bytes_read
# if we've determined this is an unframed client, just pass the read
# through to the underlying transport until we're reset again at the
# beginning of the next message.
if self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
return bytes_read + self._transport.read(bytes_left_to_read)
# we're empty and (maybe) framed. fill the buffers with the next frame.
self.readFrame(bytes_left_to_read)
return bytes_read + self._read_buffer.read(bytes_left_to_read)
def _set_client_type(self, client_type):
if client_type not in self._allowed_client_types:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Client type %d not allowed by server." % client_type,
)
self._client_type = client_type
def readFrame(self, req_sz):
# the first word could either be the length field of a framed message
# or the first bytes of an unframed message.
first_word = self._transport.readAll(I32.size)
frame_size, = I32.unpack(first_word)
is_unframed = False
if frame_size & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
self._set_client_type(THeaderClientType.UNFRAMED_BINARY)
is_unframed = True
elif (byte_index(first_word, 0) == TCompactProtocol.PROTOCOL_ID and
byte_index(first_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
self._set_client_type(THeaderClientType.UNFRAMED_COMPACT)
is_unframed = True
if is_unframed:
bytes_left_to_read = req_sz - I32.size
if bytes_left_to_read > 0:
rest = self._transport.read(bytes_left_to_read)
else:
rest = b""
self._read_buffer = BufferIO(first_word + rest)
return
# ok, we're still here so we're framed.
if frame_size > self._max_frame_size:
raise TTransportException(
TTransportException.SIZE_LIMIT,
"Frame was too large.",
)
read_buffer = BufferIO(self._transport.readAll(frame_size))
# the next word is either going to be the version field of a
# binary/compact protocol message or the magic value + flags of a
# header protocol message.
second_word = read_buffer.read(I32.size)
version, = I32.unpack(second_word)
read_buffer.seek(0)
if version >> 16 == HEADER_MAGIC:
self._set_client_type(THeaderClientType.HEADERS)
self._read_buffer = self._parse_header_format(read_buffer)
elif version & TBinaryProtocol.VERSION_MASK == TBinaryProtocol.VERSION_1:
self._set_client_type(THeaderClientType.FRAMED_BINARY)
self._read_buffer = read_buffer
elif (byte_index(second_word, 0) == TCompactProtocol.PROTOCOL_ID and
byte_index(second_word, 1) & TCompactProtocol.VERSION_MASK == TCompactProtocol.VERSION):
self._set_client_type(THeaderClientType.FRAMED_COMPACT)
self._read_buffer = read_buffer
else:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Could not detect client transport type.",
)
def _parse_header_format(self, buffer):
# make BufferIO look like TTransport for varint helpers
buffer_transport = TMemoryBuffer()
buffer_transport._buffer = buffer
buffer.read(2) # discard the magic bytes
self.flags, = U16.unpack(buffer.read(U16.size))
self.sequence_id, = I32.unpack(buffer.read(I32.size))
header_length = U16.unpack(buffer.read(U16.size))[0] * 4
end_of_headers = buffer.tell() + header_length
if end_of_headers > len(buffer.getvalue()):
raise TTransportException(
TTransportException.SIZE_LIMIT,
"Header size is larger than whole frame.",
)
self._protocol_id = readVarint(buffer_transport)
transforms = []
transform_count = readVarint(buffer_transport)
for _ in range(transform_count):
transform_id = readVarint(buffer_transport)
if transform_id not in READ_TRANSFORMS_BY_ID:
raise TApplicationException(
TApplicationException.INVALID_TRANSFORM,
"Unknown transform: %d" % transform_id,
)
transforms.append(transform_id)
transforms.reverse()
headers = {}
while buffer.tell() < end_of_headers:
header_type = readVarint(buffer_transport)
if header_type == TInfoHeaderType.KEY_VALUE:
count = readVarint(buffer_transport)
for _ in range(count):
key = _readString(buffer_transport)
value = _readString(buffer_transport)
headers[key] = value
else:
break # ignore unknown headers
self._read_headers = headers
# skip padding / anything we didn't understand
buffer.seek(end_of_headers)
payload = buffer.read()
for transform_id in transforms:
transform_fn = READ_TRANSFORMS_BY_ID[transform_id]
payload = transform_fn(payload)
return BufferIO(payload)
def write(self, buf):
self._write_buffer.write(buf)
def flush(self):
payload = self._write_buffer.getvalue()
self._write_buffer = BufferIO()
buffer = BufferIO()
if self._client_type == THeaderClientType.HEADERS:
for transform_id in self._write_transforms:
transform_fn = WRITE_TRANSFORMS_BY_ID[transform_id]
payload = transform_fn(payload)
headers = BufferIO()
writeVarint(headers, self._protocol_id)
writeVarint(headers, len(self._write_transforms))
for transform_id in self._write_transforms:
writeVarint(headers, transform_id)
if self._write_headers:
writeVarint(headers, TInfoHeaderType.KEY_VALUE)
writeVarint(headers, len(self._write_headers))
for key, value in self._write_headers.items():
_writeString(headers, key)
_writeString(headers, value)
self._write_headers = {}
padding_needed = (4 - (len(headers.getvalue()) % 4)) % 4
headers.write(b"\x00" * padding_needed)
header_bytes = headers.getvalue()
buffer.write(I32.pack(10 + len(header_bytes) + len(payload)))
buffer.write(U16.pack(HEADER_MAGIC))
buffer.write(U16.pack(self.flags))
buffer.write(I32.pack(self.sequence_id))
buffer.write(U16.pack(len(header_bytes) // 4))
buffer.write(header_bytes)
buffer.write(payload)
elif self._client_type in (THeaderClientType.FRAMED_BINARY, THeaderClientType.FRAMED_COMPACT):
buffer.write(I32.pack(len(payload)))
buffer.write(payload)
elif self._client_type in (THeaderClientType.UNFRAMED_BINARY, THeaderClientType.UNFRAMED_COMPACT):
buffer.write(payload)
else:
raise TTransportException(
TTransportException.INVALID_CLIENT_TYPE,
"Unknown client type.",
)
# the frame length field doesn't count towards the frame payload size
frame_bytes = buffer.getvalue()
frame_payload_size = len(frame_bytes) - 4
if frame_payload_size > self._max_frame_size:
raise TTransportException(
TTransportException.SIZE_LIMIT,
"Attempting to send frame that is too large.",
)
self._transport.write(frame_bytes)
self._transport.flush()
@property
def cstringio_buf(self):
return self._read_buffer
def cstringio_refill(self, partialread, reqlen):
result = bytearray(partialread)
while len(result) < reqlen:
result += self.read(reqlen - len(result))
self._read_buffer = BufferIO(result)
return self._read_buffer

View file

@ -19,7 +19,7 @@
from io import BytesIO
import os
import socket
import ssl
import sys
import warnings
import base64
@ -34,17 +34,20 @@ import six
class THttpClient(TTransportBase):
"""Http implementation of TTransport base."""
def __init__(self, uri_or_host, port=None, path=None):
"""THttpClient supports two different types constructor parameters.
def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, key_file=None, ssl_context=None):
"""THttpClient supports two different types of construction:
THttpClient(host, port, path) - deprecated
THttpClient(uri)
THttpClient(uri, [port=<n>, path=<s>, cafile=<filename>, cert_file=<filename>, key_file=<filename>, ssl_context=<context>])
Only the second supports https.
Only the second supports https. To properly authenticate against the server,
provide the client's identity by specifying cert_file and key_file. To properly
authenticate the server, specify either cafile or ssl_context with a CA defined.
NOTE: if both cafile and ssl_context are defined, ssl_context will override cafile.
"""
if port is not None:
warnings.warn(
"Please use the THttpClient('http://host:port/path') syntax",
"Please use the THttpClient('http{s}://host:port/path') constructor",
DeprecationWarning,
stacklevel=2)
self.host = uri_or_host
@ -60,6 +63,9 @@ class THttpClient(TTransportBase):
self.port = parsed.port or http_client.HTTP_PORT
elif self.scheme == 'https':
self.port = parsed.port or http_client.HTTPS_PORT
self.certfile = cert_file
self.keyfile = key_file
self.context = ssl.create_default_context(cafile=cafile) if (cafile and not ssl_context) else ssl_context
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
@ -100,12 +106,17 @@ class THttpClient(TTransportBase):
def open(self):
if self.scheme == 'http':
self.__http = http_client.HTTPConnection(self.host, self.port)
self.__http = http_client.HTTPConnection(self.host, self.port,
timeout=self.__timeout)
elif self.scheme == 'https':
self.__http = http_client.HTTPSConnection(self.host, self.port)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
self.__http = http_client.HTTPSConnection(self.host, self.port,
key_file=self.keyfile,
cert_file=self.certfile,
timeout=self.__timeout,
context=self.context)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
def close(self):
self.__http.close()
@ -116,9 +127,6 @@ class THttpClient(TTransportBase):
return self.__http is not None
def setTimeout(self, ms):
if not hasattr(socket, 'getdefaulttimeout'):
raise NotImplementedError
if ms is None:
self.__timeout = None
else:
@ -133,17 +141,6 @@ class THttpClient(TTransportBase):
def write(self, buf):
self.__wbuf.write(buf)
def __withTimeout(f):
def _f(*args, **kwargs):
orig_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(args[0].__timeout)
try:
result = f(*args, **kwargs)
finally:
socket.setdefaulttimeout(orig_timeout)
return result
return _f
def flush(self):
if self.isOpen():
self.close()
@ -188,7 +185,3 @@ class THttpClient(TTransportBase):
self.code = self.__http_response.status
self.message = self.__http_response.reason
self.headers = self.__http_response.msg
# Decorate if we know how to timeout
if hasattr(socket, 'getdefaulttimeout'):
flush = __withTimeout(flush)

View file

@ -40,10 +40,10 @@ class TSSLBase(object):
# ciphers argument is not available for Python < 2.7.0
_has_ciphers = sys.hexversion >= 0x020700F0
# For pythoon >= 2.7.9, use latest TLS that both client and server
# For python >= 2.7.9, use latest TLS that both client and server
# supports.
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
# For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
# For python < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
# unavailable.
_default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
ssl.PROTOCOL_TLSv1
@ -368,7 +368,7 @@ class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
plain_client, addr = self.handle.accept()
try:
client = self._wrap_socket(plain_client)
except ssl.SSLError:
except (ssl.SSLError, socket.error, OSError):
logger.exception('Error while accepting from %s', addr)
# failed handshake/ssl wrap, close socket to client
plain_client.close()

View file

@ -159,6 +159,15 @@ class TServerSocket(TSocketBase, TServerTransportBase):
self._unix_socket = unix_socket
self._socket_family = socket_family
self.handle = None
self._backlog = 128
def setBacklog(self, backlog=None):
if not self.handle:
self._backlog = backlog
else:
# We cann't update backlog when it is already listening, since the
# handle has been created.
logger.warn('You have to set backlog before listen.')
def listen(self):
res0 = self._resolveAddr()
@ -183,7 +192,7 @@ class TServerSocket(TSocketBase, TServerTransportBase):
if hasattr(self.handle, 'settimeout'):
self.handle.settimeout(None)
self.handle.bind(res[4])
self.handle.listen(128)
self.handle.listen(self._backlog)
def accept(self):
client, addr = self.handle.accept()

View file

@ -32,6 +32,7 @@ class TTransportException(TException):
END_OF_FILE = 4
NEGATIVE_SIZE = 5
SIZE_LIMIT = 6
INVALID_CLIENT_TYPE = 7
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
@ -58,10 +59,11 @@ class TTransportBase(object):
have = 0
while (have < sz):
chunk = self.read(sz - have)
have += len(chunk)
chunkLen = len(chunk)
have += chunkLen
buff += chunk
if len(chunk) == 0:
if chunkLen == 0:
raise EOFError()
return buff
@ -168,7 +170,6 @@ class TBufferedTransport(TTransportBase, CReadableTransport):
# on exception reset wbuf so it doesn't contain a partial function call
self.__wbuf = BufferIO()
raise e
self.__wbuf.getvalue()
def flush(self):
out = self.__wbuf.getvalue()
@ -205,7 +206,7 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
TODO(dreiss): Make this work like the C++ version.
"""
def __init__(self, value=None):
def __init__(self, value=None, offset=0):
"""value -- a value to read from for stringio
If value is set, this will be a transport for reading,
@ -214,6 +215,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
self._buffer = BufferIO(value)
else:
self._buffer = BufferIO()
if offset:
self._buffer.seek(offset)
def isOpen(self):
return not self._buffer.closed

View file

@ -20,7 +20,7 @@
from io import BytesIO
import struct
from zope.interface import implements, Interface, Attribute
from zope.interface import implementer, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone
from twisted.internet import defer
@ -257,10 +257,9 @@ class IThriftClientFactory(Interface):
oprot_factory = Attribute("Output protocol factory")
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
implements(IThriftServerFactory)
protocol = ThriftServerProtocol
def __init__(self, processor, iprot_factory, oprot_factory=None):
@ -272,10 +271,9 @@ class ThriftServerFactory(ServerFactory):
self.oprot_factory = oprot_factory
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
implements(IThriftClientFactory)
protocol = ThriftClientProtocol
def __init__(self, client_class, iprot_factory, oprot_factory=None):

View file

@ -25,7 +25,7 @@ from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
def legacy_validate_callback(self, cert, hostname):
def legacy_validate_callback(cert, hostname):
"""legacy method to validate the peer's SSL certificate, and to check
the commonName of the certificate to ensure it matches the hostname we
used to make this connection. Does not support subjectAltName records
@ -36,7 +36,7 @@ def legacy_validate_callback(self, cert, hostname):
if 'subject' not in cert:
raise TTransportException(
TTransportException.NOT_OPEN,
'No SSL certificate found from %s:%s' % (self.host, self.port))
'No SSL certificate found from %s' % hostname)
fields = cert['subject']
for field in fields:
# ensure structure we get back is what we expect
@ -57,7 +57,7 @@ def legacy_validate_callback(self, cert, hostname):
raise TTransportException(
TTransportException.UNKNOWN,
'Hostname we connected to "%s" doesn\'t match certificate '
'provided commonName "%s"' % (self.host, certhost))
'provided commonName "%s"' % (hostname, certhost))
raise TTransportException(
TTransportException.UNKNOWN,
'Could not validate SSL certificate from host "%s". Cert=%s'
@ -96,4 +96,5 @@ def _optional_dependencies():
match = legacy_validate_callback
return ipaddr, match
_match_has_ipaddress, _match_hostname = _optional_dependencies()

View file

@ -42,7 +42,7 @@ CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
TEST_CIPHERS = 'DES-CBC3-SHA'
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
class ServerAcceptor(threading.Thread):
@ -95,6 +95,11 @@ class ServerAcceptor(threading.Thread):
self._client_accepted.wait()
return self._client
def close(self):
if self._client:
self._client.close()
self._server.close()
# Python 2.6 compat
class AssertRaises(object):
@ -125,9 +130,7 @@ class TSSLSocketTest(unittest.TestCase):
client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
yield acc, client
finally:
if acc.client:
acc.client.close()
server.close()
acc.close()
def _assert_connection_failure(self, server, path=None, **client_args):
logging.disable(logging.CRITICAL)
@ -237,6 +240,9 @@ class TSSLSocketTest(unittest.TestCase):
self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
def test_client_cert(self):
if not _match_has_ipaddress:
print('skipping test_client_cert')
return
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
@ -331,9 +337,10 @@ class TSSLSocketTest(unittest.TestCase):
self._assert_connection_success(server, ssl_context=client_context)
if __name__ == '__main__':
logging.basicConfig(level=logging.WARN)
from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
from thrift.transport.TTransport import TTransportException
unittest.main()

View file

@ -46,5 +46,6 @@ class TestJSONString(unittest.TestCase):
unicode_text = unicode_text.encode('utf8')
self.assertEqual(protocol.readString(), unicode_text)
if __name__ == '__main__':
unittest.main()