Checking in vendor folder for ease of using go get.

This commit is contained in:
Renan DelValle 2018-10-23 23:32:59 -07:00
parent 7a1251853b
commit cdb4b5a1d0
No known key found for this signature in database
GPG key ID: C240AD6D6F443EC9
3554 changed files with 1270116 additions and 0 deletions

31
vendor/git.apache.org/thrift.git/lib/py/CMakeLists.txt generated vendored Normal file
View file

@ -0,0 +1,31 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
include_directories(${PYTHON_INCLUDE_DIRS})
add_custom_target(python_build ALL
COMMAND ${PYTHON_EXECUTABLE} setup.py build
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "Building Python library"
)
if(BUILD_TESTING)
add_test(PythonTestSSLSocket ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/test_sslsocket.py)
add_test(PythonThriftJson ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_json.py)
endif()

58
vendor/git.apache.org/thrift.git/lib/py/Makefile.am generated vendored Normal file
View file

@ -0,0 +1,58 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
AUTOMAKE_OPTIONS = serial-tests
DESTDIR ?= /
if WITH_PY3
py3-build:
$(PYTHON3) setup.py build
py3-test: py3-build
$(PYTHON3) test/thrift_json.py
$(PYTHON3) test/test_sslsocket.py
else
py3-build:
py3-test:
endif
all-local: py3-build
$(PYTHON) setup.py build
# We're ignoring prefix here because site-packages seems to be
# the equivalent of /usr/local/lib in Python land.
# Old version (can't put inline because it's not portable).
#$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS)
install-exec-hook:
$(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS)
clean-local:
$(RM) -r build
check-local: all py3-test
$(PYTHON) test/thrift_json.py
$(PYTHON) test/test_sslsocket.py
EXTRA_DIST = \
CMakeLists.txt \
coding_standards.md \
compat \
setup.py \
setup.cfg \
src \
test \
README.md

35
vendor/git.apache.org/thrift.git/lib/py/README.md generated vendored Normal file
View file

@ -0,0 +1,35 @@
Thrift Python Software Library
License
=======
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Using Thrift with Python
========================
Thrift is provided as a set of Python packages. The top level package is
thrift, and there are subpackages for the protocol, transport, and server
code. Each package contains modules using standard Thrift naming conventions
(i.e. TProtocol, TTransport) and implementations in corresponding modules
(i.e. TSocket). There is also a subpackage reflection, which contains
the generated code for the reflection structures.
The Python libraries can be installed manually using the provided setup.py
file, or automatically using the install hook provided via autoconf/automake.
To use the latter, become superuser and do make install.

View file

@ -0,0 +1,7 @@
## Python Coding Standards
Please follow:
* [Thrift General Coding Standards](/doc/coding_standards.md)
* Code Style for Python Code [PEP8](http://legacy.python.org/dev/peps/pep-0008/)
When in doubt - check with <http://www.pylint.org/> or online with <http://pep8online.com>.

View file

@ -0,0 +1,247 @@
// ISO C9x compliant stdint.h for Microsoft Visual Studio
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
//
// Copyright (c) 2006-2008 Alexander Chemeris
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. The name of the author may be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
///////////////////////////////////////////////////////////////////////////////
#ifndef _MSC_VER // [
#error "Use this header only with Microsoft Visual C++ compilers!"
#endif // _MSC_VER ]
#ifndef _MSC_STDINT_H_ // [
#define _MSC_STDINT_H_
#if _MSC_VER > 1000
#pragma once
#endif
#include <limits.h>
// For Visual Studio 6 in C++ mode and for many Visual Studio versions when
// compiling for ARM we should wrap <wchar.h> include with 'extern "C++" {}'
// or compiler give many errors like this:
// error C2733: second C linkage of overloaded function 'wmemchr' not allowed
#ifdef __cplusplus
extern "C" {
#endif
# include <wchar.h>
#ifdef __cplusplus
}
#endif
// Define _W64 macros to mark types changing their size, like intptr_t.
#ifndef _W64
# if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300
# define _W64 __w64
# else
# define _W64
# endif
#endif
// 7.18.1 Integer types
// 7.18.1.1 Exact-width integer types
// Visual Studio 6 and Embedded Visual C++ 4 doesn't
// realize that, e.g. char has the same size as __int8
// so we give up on __intX for them.
#if (_MSC_VER < 1300)
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
#else
typedef signed __int8 int8_t;
typedef signed __int16 int16_t;
typedef signed __int32 int32_t;
typedef unsigned __int8 uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
#endif
typedef signed __int64 int64_t;
typedef unsigned __int64 uint64_t;
// 7.18.1.2 Minimum-width integer types
typedef int8_t int_least8_t;
typedef int16_t int_least16_t;
typedef int32_t int_least32_t;
typedef int64_t int_least64_t;
typedef uint8_t uint_least8_t;
typedef uint16_t uint_least16_t;
typedef uint32_t uint_least32_t;
typedef uint64_t uint_least64_t;
// 7.18.1.3 Fastest minimum-width integer types
typedef int8_t int_fast8_t;
typedef int16_t int_fast16_t;
typedef int32_t int_fast32_t;
typedef int64_t int_fast64_t;
typedef uint8_t uint_fast8_t;
typedef uint16_t uint_fast16_t;
typedef uint32_t uint_fast32_t;
typedef uint64_t uint_fast64_t;
// 7.18.1.4 Integer types capable of holding object pointers
#ifdef _WIN64 // [
typedef signed __int64 intptr_t;
typedef unsigned __int64 uintptr_t;
#else // _WIN64 ][
typedef _W64 signed int intptr_t;
typedef _W64 unsigned int uintptr_t;
#endif // _WIN64 ]
// 7.18.1.5 Greatest-width integer types
typedef int64_t intmax_t;
typedef uint64_t uintmax_t;
// 7.18.2 Limits of specified-width integer types
#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
// 7.18.2.1 Limits of exact-width integer types
#define INT8_MIN ((int8_t)_I8_MIN)
#define INT8_MAX _I8_MAX
#define INT16_MIN ((int16_t)_I16_MIN)
#define INT16_MAX _I16_MAX
#define INT32_MIN ((int32_t)_I32_MIN)
#define INT32_MAX _I32_MAX
#define INT64_MIN ((int64_t)_I64_MIN)
#define INT64_MAX _I64_MAX
#define UINT8_MAX _UI8_MAX
#define UINT16_MAX _UI16_MAX
#define UINT32_MAX _UI32_MAX
#define UINT64_MAX _UI64_MAX
// 7.18.2.2 Limits of minimum-width integer types
#define INT_LEAST8_MIN INT8_MIN
#define INT_LEAST8_MAX INT8_MAX
#define INT_LEAST16_MIN INT16_MIN
#define INT_LEAST16_MAX INT16_MAX
#define INT_LEAST32_MIN INT32_MIN
#define INT_LEAST32_MAX INT32_MAX
#define INT_LEAST64_MIN INT64_MIN
#define INT_LEAST64_MAX INT64_MAX
#define UINT_LEAST8_MAX UINT8_MAX
#define UINT_LEAST16_MAX UINT16_MAX
#define UINT_LEAST32_MAX UINT32_MAX
#define UINT_LEAST64_MAX UINT64_MAX
// 7.18.2.3 Limits of fastest minimum-width integer types
#define INT_FAST8_MIN INT8_MIN
#define INT_FAST8_MAX INT8_MAX
#define INT_FAST16_MIN INT16_MIN
#define INT_FAST16_MAX INT16_MAX
#define INT_FAST32_MIN INT32_MIN
#define INT_FAST32_MAX INT32_MAX
#define INT_FAST64_MIN INT64_MIN
#define INT_FAST64_MAX INT64_MAX
#define UINT_FAST8_MAX UINT8_MAX
#define UINT_FAST16_MAX UINT16_MAX
#define UINT_FAST32_MAX UINT32_MAX
#define UINT_FAST64_MAX UINT64_MAX
// 7.18.2.4 Limits of integer types capable of holding object pointers
#ifdef _WIN64 // [
# define INTPTR_MIN INT64_MIN
# define INTPTR_MAX INT64_MAX
# define UINTPTR_MAX UINT64_MAX
#else // _WIN64 ][
# define INTPTR_MIN INT32_MIN
# define INTPTR_MAX INT32_MAX
# define UINTPTR_MAX UINT32_MAX
#endif // _WIN64 ]
// 7.18.2.5 Limits of greatest-width integer types
#define INTMAX_MIN INT64_MIN
#define INTMAX_MAX INT64_MAX
#define UINTMAX_MAX UINT64_MAX
// 7.18.3 Limits of other integer types
#ifdef _WIN64 // [
# define PTRDIFF_MIN _I64_MIN
# define PTRDIFF_MAX _I64_MAX
#else // _WIN64 ][
# define PTRDIFF_MIN _I32_MIN
# define PTRDIFF_MAX _I32_MAX
#endif // _WIN64 ]
#define SIG_ATOMIC_MIN INT_MIN
#define SIG_ATOMIC_MAX INT_MAX
#ifndef SIZE_MAX // [
# ifdef _WIN64 // [
# define SIZE_MAX _UI64_MAX
# else // _WIN64 ][
# define SIZE_MAX _UI32_MAX
# endif // _WIN64 ]
#endif // SIZE_MAX ]
// WCHAR_MIN and WCHAR_MAX are also defined in <wchar.h>
#ifndef WCHAR_MIN // [
# define WCHAR_MIN 0
#endif // WCHAR_MIN ]
#ifndef WCHAR_MAX // [
# define WCHAR_MAX _UI16_MAX
#endif // WCHAR_MAX ]
#define WINT_MIN 0
#define WINT_MAX _UI16_MAX
#endif // __STDC_LIMIT_MACROS ]
// 7.18.4 Limits of other integer types
#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260
// 7.18.4.1 Macros for minimum-width integer constants
#define INT8_C(val) val##i8
#define INT16_C(val) val##i16
#define INT32_C(val) val##i32
#define INT64_C(val) val##i64
#define UINT8_C(val) val##ui8
#define UINT16_C(val) val##ui16
#define UINT32_C(val) val##ui32
#define UINT64_C(val) val##ui64
// 7.18.4.2 Macros for greatest-width integer constants
#define INTMAX_C INT64_C
#define UINTMAX_C UINT64_C
#endif // __STDC_CONSTANT_MACROS ]
#endif // _MSC_STDINT_H_ ]

6
vendor/git.apache.org/thrift.git/lib/py/setup.cfg generated vendored Normal file
View file

@ -0,0 +1,6 @@
[install]
optimize = 1
[metadata]
description-file = README.md
[flake8]
max-line-length = 100

134
vendor/git.apache.org/thrift.git/lib/py/setup.py generated vendored Normal file
View file

@ -0,0 +1,134 @@
#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
try:
from setuptools import setup, Extension
except:
from distutils.core import setup, Extension
from distutils.command.build_ext import build_ext
from distutils.errors import CCompilerError, DistutilsExecError, DistutilsPlatformError
# Fix to build sdist under vagrant
import os
if 'vagrant' in str(os.environ):
del os.link
include_dirs = ['src']
if sys.platform == 'win32':
include_dirs.append('compat/win32')
ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError, IOError)
else:
ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError)
class BuildFailed(Exception):
pass
class ve_build_ext(build_ext):
def run(self):
try:
build_ext.run(self)
except DistutilsPlatformError:
raise BuildFailed()
def build_extension(self, ext):
try:
build_ext.build_extension(self, ext)
except ext_errors:
raise BuildFailed()
def run_setup(with_binary):
if with_binary:
extensions = dict(
ext_modules=[
Extension('thrift.protocol.fastbinary',
sources=[
'src/ext/module.cpp',
'src/ext/types.cpp',
'src/ext/binary.cpp',
'src/ext/compact.cpp',
],
include_dirs=include_dirs,
)
],
cmdclass=dict(build_ext=ve_build_ext)
)
else:
extensions = dict()
ssl_deps = []
if sys.version_info[0] == 2:
ssl_deps.append('ipaddress')
if sys.hexversion < 0x03050000:
ssl_deps.append('backports.ssl_match_hostname>=3.5')
tornado_deps = ['tornado>=4.0']
twisted_deps = ['twisted']
setup(name='thrift',
version='0.10.0',
description='Python bindings for the Apache Thrift RPC system',
author='Thrift Developers',
author_email='dev@thrift.apache.org',
url='http://thrift.apache.org',
license='Apache License 2.0',
install_requires=['six>=1.7.2'],
extras_require={
'ssl': ssl_deps,
'tornado': tornado_deps,
'twisted': twisted_deps,
'all': ssl_deps + tornado_deps + twisted_deps,
},
packages=[
'thrift',
'thrift.protocol',
'thrift.transport',
'thrift.server',
],
package_dir={'thrift': 'src'},
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Console',
'Intended Audience :: Developers',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 3',
'Topic :: Software Development :: Libraries',
'Topic :: System :: Networking'
],
**extensions
)
try:
with_binary = True
run_setup(with_binary)
except BuildFailed:
print()
print('*' * 80)
print("An error occurred while trying to compile with the C extension enabled")
print("Attempting to build without the extension now")
print('*' * 80)
print()
run_setup(False)

View file

@ -0,0 +1,55 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TProcessor, TMessageType, TException
from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
class TMultiplexedProcessor(TProcessor):
def __init__(self):
self.services = {}
def registerProcessor(self, serviceName, processor):
self.services[serviceName] = processor
def process(self, iprot, oprot):
(name, type, seqid) = iprot.readMessageBegin()
if type != TMessageType.CALL and type != TMessageType.ONEWAY:
raise TException("TMultiplex protocol only supports CALL & ONEWAY")
index = name.find(TMultiplexedProtocol.SEPARATOR)
if index < 0:
raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexProtocol in your client?")
serviceName = name[0:index]
call = name[index + len(TMultiplexedProtocol.SEPARATOR):]
if serviceName not in self.services:
raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")
standardMessage = (call, type, seqid)
return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, messageBegin):
TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
self.messageBegin = messageBegin
def readMessageBegin(self):
return self.messageBegin

36
vendor/git.apache.org/thrift.git/lib/py/src/TSCons.py generated vendored Normal file
View file

@ -0,0 +1,36 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from os import path
from SCons.Builder import Builder
from six.moves import map
def scons_env(env, add=''):
opath = path.dirname(path.abspath('$TARGET'))
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
cppbuild = Builder(action=lstr)
env.Append(BUILDERS={'ThriftCpp': cppbuild})
def gen_cpp(env, dir, file):
scons_env(env)
suffixes = ['_types.h', '_types.cpp']
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
return env.ThriftCpp(targets, dir + file + '.thrift')

View file

@ -0,0 +1,38 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .protocol import TBinaryProtocol
from .transport import TTransport
def serialize(thrift_object,
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
transport = TTransport.TMemoryBuffer()
protocol = protocol_factory.getProtocol(transport)
thrift_object.write(protocol)
return transport.getvalue()
def deserialize(base,
buf,
protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
transport = TTransport.TMemoryBuffer(buf)
protocol = protocol_factory.getProtocol(transport)
base.read(protocol)
return base

188
vendor/git.apache.org/thrift.git/lib/py/src/TTornado.py generated vendored Normal file
View file

@ -0,0 +1,188 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import
import logging
import socket
import struct
from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
from io import BytesIO
from collections import deque
from contextlib import contextmanager
from tornado import gen, iostream, ioloop, tcpserver, concurrent
__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
logger = logging.getLogger(__name__)
class _Lock(object):
def __init__(self):
self._waiters = deque()
def acquired(self):
return len(self._waiters) > 0
@gen.coroutine
def acquire(self):
blocker = self._waiters[-1] if self.acquired() else None
future = concurrent.Future()
self._waiters.append(future)
if blocker:
yield blocker
raise gen.Return(self._lock_context())
def release(self):
assert self.acquired(), 'Lock not aquired'
future = self._waiters.popleft()
future.set_result(None)
@contextmanager
def _lock_context(self):
try:
yield
finally:
self.release()
class TTornadoStreamTransport(TTransportBase):
"""a framed, buffered transport over a Tornado stream"""
def __init__(self, host, port, stream=None, io_loop=None):
self.host = host
self.port = port
self.io_loop = io_loop or ioloop.IOLoop.current()
self.__wbuf = BytesIO()
self._read_lock = _Lock()
# servers provide a ready-to-go stream
self.stream = stream
def with_timeout(self, timeout, future):
return gen.with_timeout(timeout, future, self.io_loop)
@gen.coroutine
def open(self, timeout=None):
logger.debug('socket connecting')
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
self.stream = iostream.IOStream(sock)
try:
connect = self.stream.connect((self.host, self.port))
if timeout is not None:
yield self.with_timeout(timeout, connect)
else:
yield connect
except (socket.error, IOError, ioloop.TimeoutError) as e:
message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
raise TTransportException(
type=TTransportException.NOT_OPEN,
message=message)
raise gen.Return(self)
def set_close_callback(self, callback):
"""
Should be called only after open() returns
"""
self.stream.set_close_callback(callback)
def close(self):
# don't raise if we intend to close
self.stream.set_close_callback(None)
self.stream.close()
def read(self, _):
# The generated code for Tornado shouldn't do individual reads -- only
# frames at a time
assert False, "you're doing it wrong"
@contextmanager
def io_exception_context(self):
try:
yield
except (socket.error, IOError) as e:
raise TTransportException(
type=TTransportException.END_OF_FILE,
message=str(e))
except iostream.StreamBufferFullError as e:
raise TTransportException(
type=TTransportException.UNKNOWN,
message=str(e))
@gen.coroutine
def readFrame(self):
# IOStream processes reads one at a time
with (yield self._read_lock.acquire()):
with self.io_exception_context():
frame_header = yield self.stream.read_bytes(4)
if len(frame_header) == 0:
raise iostream.StreamClosedError('Read zero bytes from stream')
frame_length, = struct.unpack('!i', frame_header)
frame = yield self.stream.read_bytes(frame_length)
raise gen.Return(frame)
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
frame = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
frame_length = struct.pack('!i', len(frame))
self.__wbuf = BytesIO()
with self.io_exception_context():
return self.stream.write(frame_length + frame)
class TTornadoServer(tcpserver.TCPServer):
def __init__(self, processor, iprot_factory, oprot_factory=None,
*args, **kwargs):
super(TTornadoServer, self).__init__(*args, **kwargs)
self._processor = processor
self._iprot_factory = iprot_factory
self._oprot_factory = (oprot_factory if oprot_factory is not None
else iprot_factory)
@gen.coroutine
def handle_stream(self, stream, address):
host, port = address[:2]
trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
io_loop=self.io_loop)
oprot = self._oprot_factory.getProtocol(trans)
try:
while not trans.stream.closed():
try:
frame = yield trans.readFrame()
except TTransportException as e:
if e.type == TTransportException.END_OF_FILE:
break
else:
raise
tr = TMemoryBuffer(frame)
iprot = self._iprot_factory.getProtocol(tr)
yield self._processor.process(iprot, oprot)
except Exception:
logger.exception('thrift exception in handle_stream')
trans.close()
logger.info('client disconnected %s:%d', host, port)

192
vendor/git.apache.org/thrift.git/lib/py/src/Thrift.py generated vendored Normal file
View file

@ -0,0 +1,192 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
class TType(object):
STOP = 0
VOID = 1
BOOL = 2
BYTE = 3
I08 = 3
DOUBLE = 4
I16 = 6
I32 = 8
I64 = 10
STRING = 11
UTF7 = 11
STRUCT = 12
MAP = 13
SET = 14
LIST = 15
UTF8 = 16
UTF16 = 17
_VALUES_TO_NAMES = (
'STOP',
'VOID',
'BOOL',
'BYTE',
'DOUBLE',
None,
'I16',
None,
'I32',
None,
'I64',
'STRING',
'STRUCT',
'MAP',
'SET',
'LIST',
'UTF8',
'UTF16',
)
class TMessageType(object):
CALL = 1
REPLY = 2
EXCEPTION = 3
ONEWAY = 4
class TProcessor(object):
"""Base class for procsessor, which works on two streams."""
def process(iprot, oprot):
pass
class TException(Exception):
"""Base class for all thrift exceptions."""
# BaseException.message is deprecated in Python v[2.6,3.0)
if (2, 6, 0) <= sys.version_info < (3, 0):
def _get_message(self):
return self._message
def _set_message(self, message):
self._message = message
message = property(_get_message, _set_message)
def __init__(self, message=None):
Exception.__init__(self, message)
self.message = message
class TApplicationException(TException):
"""Application level thrift exceptions."""
UNKNOWN = 0
UNKNOWN_METHOD = 1
INVALID_MESSAGE_TYPE = 2
WRONG_METHOD_NAME = 3
BAD_SEQUENCE_ID = 4
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
INVALID_TRANSFORM = 8
INVALID_PROTOCOL = 9
UNSUPPORTED_CLIENT_TYPE = 10
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
self.type = type
def __str__(self):
if self.message:
return self.message
elif self.type == self.UNKNOWN_METHOD:
return 'Unknown method'
elif self.type == self.INVALID_MESSAGE_TYPE:
return 'Invalid message type'
elif self.type == self.WRONG_METHOD_NAME:
return 'Wrong method name'
elif self.type == self.BAD_SEQUENCE_ID:
return 'Bad sequence ID'
elif self.type == self.MISSING_RESULT:
return 'Missing result'
elif self.type == self.INTERNAL_ERROR:
return 'Internal error'
elif self.type == self.PROTOCOL_ERROR:
return 'Protocol error'
elif self.type == self.INVALID_TRANSFORM:
return 'Invalid transform'
elif self.type == self.INVALID_PROTOCOL:
return 'Invalid protocol'
elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
return 'Unsupported client type'
else:
return 'Default (unknown) TApplicationException'
def read(self, iprot):
iprot.readStructBegin()
while True:
(fname, ftype, fid) = iprot.readFieldBegin()
if ftype == TType.STOP:
break
if fid == 1:
if ftype == TType.STRING:
self.message = iprot.readString()
else:
iprot.skip(ftype)
elif fid == 2:
if ftype == TType.I32:
self.type = iprot.readI32()
else:
iprot.skip(ftype)
else:
iprot.skip(ftype)
iprot.readFieldEnd()
iprot.readStructEnd()
def write(self, oprot):
oprot.writeStructBegin('TApplicationException')
if self.message is not None:
oprot.writeFieldBegin('message', TType.STRING, 1)
oprot.writeString(self.message)
oprot.writeFieldEnd()
if self.type is not None:
oprot.writeFieldBegin('type', TType.I32, 2)
oprot.writeI32(self.type)
oprot.writeFieldEnd()
oprot.writeFieldStop()
oprot.writeStructEnd()
class TFrozenDict(dict):
"""A dictionary that is "frozen" like a frozenset"""
def __init__(self, *args, **kwargs):
super(TFrozenDict, self).__init__(*args, **kwargs)
# Sort the items so they will be in a consistent order.
# XOR in the hash of the class so we don't collide with
# the hash of a list of tuples.
self.__hashval = hash(TFrozenDict) ^ hash(tuple(sorted(self.items())))
def __setitem__(self, *args):
raise TypeError("Can't modify frozen TFreezableDict")
def __delitem__(self, *args):
raise TypeError("Can't modify frozen TFreezableDict")
def __hash__(self):
return self.__hashval

View file

@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['Thrift', 'TSCons']

40
vendor/git.apache.org/thrift.git/lib/py/src/compat.py generated vendored Normal file
View file

@ -0,0 +1,40 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
if sys.version_info[0] == 2:
from cStringIO import StringIO as BufferIO
def binary_to_str(bin_val):
return bin_val
def str_to_binary(str_val):
return str_val
else:
from io import BytesIO as BufferIO # noqa
def binary_to_str(bin_val):
return bin_val.decode('utf8')
def str_to_binary(str_val):
return bytes(str_val, 'utf8')

View file

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "ext/binary.h"
namespace apache {
namespace thrift {
namespace py {
bool BinaryProtocol::readFieldBegin(TType& type, int16_t& tag) {
uint8_t b = 0;
if (!readByte(b)) {
return false;
}
type = static_cast<TType>(b);
if (type == T_STOP) {
return true;
}
return readI16(tag);
}
}
}
}

View file

@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_PY_BINARY_H
#define THRIFT_PY_BINARY_H
#include <Python.h>
#include "ext/protocol.h"
#include "ext/endian.h"
#include <stdint.h>
namespace apache {
namespace thrift {
namespace py {
class BinaryProtocol : public ProtocolBase<BinaryProtocol> {
public:
virtual ~BinaryProtocol() {}
void writeI8(int8_t val) { writeBuffer(reinterpret_cast<char*>(&val), sizeof(int8_t)); }
void writeI16(int16_t val) {
int16_t net = static_cast<int16_t>(htons(val));
writeBuffer(reinterpret_cast<char*>(&net), sizeof(int16_t));
}
void writeI32(int32_t val) {
int32_t net = static_cast<int32_t>(htonl(val));
writeBuffer(reinterpret_cast<char*>(&net), sizeof(int32_t));
}
void writeI64(int64_t val) {
int64_t net = static_cast<int64_t>(htonll(val));
writeBuffer(reinterpret_cast<char*>(&net), sizeof(int64_t));
}
void writeDouble(double dub) {
// Unfortunately, bitwise_cast doesn't work in C. Bad C!
union {
double f;
int64_t t;
} transfer;
transfer.f = dub;
writeI64(transfer.t);
}
void writeBool(int v) { writeByte(static_cast<uint8_t>(v)); }
void writeString(PyObject* value, int32_t len) {
writeI32(len);
writeBuffer(PyBytes_AS_STRING(value), len);
}
bool writeListBegin(PyObject* value, const SetListTypeArgs& parsedargs, int32_t len) {
writeByte(parsedargs.element_type);
writeI32(len);
return true;
}
bool writeMapBegin(PyObject* value, const MapTypeArgs& parsedargs, int32_t len) {
writeByte(parsedargs.ktag);
writeByte(parsedargs.vtag);
writeI32(len);
return true;
}
bool writeStructBegin() { return true; }
bool writeStructEnd() { return true; }
bool writeField(PyObject* value, const StructItemSpec& parsedspec) {
writeByte(static_cast<uint8_t>(parsedspec.type));
writeI16(parsedspec.tag);
return encodeValue(value, parsedspec.type, parsedspec.typeargs);
}
void writeFieldStop() { writeByte(static_cast<uint8_t>(T_STOP)); }
bool readBool(bool& val) {
char* buf;
if (!readBytes(&buf, 1)) {
return false;
}
val = buf[0] == 1;
return true;
}
bool readI8(int8_t& val) {
char* buf;
if (!readBytes(&buf, 1)) {
return false;
}
val = buf[0];
return true;
}
bool readI16(int16_t& val) {
char* buf;
if (!readBytes(&buf, sizeof(int16_t))) {
return false;
}
val = static_cast<int16_t>(ntohs(*reinterpret_cast<int16_t*>(buf)));
return true;
}
bool readI32(int32_t& val) {
char* buf;
if (!readBytes(&buf, sizeof(int32_t))) {
return false;
}
val = static_cast<int32_t>(ntohl(*reinterpret_cast<int32_t*>(buf)));
return true;
}
bool readI64(int64_t& val) {
char* buf;
if (!readBytes(&buf, sizeof(int64_t))) {
return false;
}
val = static_cast<int64_t>(ntohll(*reinterpret_cast<int64_t*>(buf)));
return true;
}
bool readDouble(double& val) {
union {
int64_t f;
double t;
} transfer;
if (!readI64(transfer.f)) {
return false;
}
val = transfer.t;
return true;
}
int32_t readString(char** buf) {
int32_t len = 0;
if (!readI32(len) || !checkLengthLimit(len, stringLimit()) || !readBytes(buf, len)) {
return -1;
}
return len;
}
int32_t readListBegin(TType& etype) {
int32_t len;
uint8_t b = 0;
if (!readByte(b) || !readI32(len) || !checkLengthLimit(len, containerLimit())) {
return -1;
}
etype = static_cast<TType>(b);
return len;
}
int32_t readMapBegin(TType& ktype, TType& vtype) {
int32_t len;
uint8_t k, v;
if (!readByte(k) || !readByte(v) || !readI32(len) || !checkLengthLimit(len, containerLimit())) {
return -1;
}
ktype = static_cast<TType>(k);
vtype = static_cast<TType>(v);
return len;
}
bool readStructBegin() { return true; }
bool readStructEnd() { return true; }
bool readFieldBegin(TType& type, int16_t& tag);
#define SKIPBYTES(n) \
do { \
if (!readBytes(&dummy_buf_, (n))) { \
return false; \
} \
return true; \
} while (0)
bool skipBool() { SKIPBYTES(1); }
bool skipByte() { SKIPBYTES(1); }
bool skipI16() { SKIPBYTES(2); }
bool skipI32() { SKIPBYTES(4); }
bool skipI64() { SKIPBYTES(8); }
bool skipDouble() { SKIPBYTES(8); }
bool skipString() {
int32_t len;
if (!readI32(len)) {
return false;
}
SKIPBYTES(len);
}
#undef SKIPBYTES
private:
char* dummy_buf_;
};
}
}
}
#endif // THRIFT_PY_BINARY_H

View file

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "ext/compact.h"
namespace apache {
namespace thrift {
namespace py {
const uint8_t CompactProtocol::TTypeToCType[] = {
CT_STOP, // T_STOP
0, // unused
CT_BOOLEAN_TRUE, // T_BOOL
CT_BYTE, // T_BYTE
CT_DOUBLE, // T_DOUBLE
0, // unused
CT_I16, // T_I16
0, // unused
CT_I32, // T_I32
0, // unused
CT_I64, // T_I64
CT_BINARY, // T_STRING
CT_STRUCT, // T_STRUCT
CT_MAP, // T_MAP
CT_SET, // T_SET
CT_LIST, // T_LIST
};
bool CompactProtocol::readFieldBegin(TType& type, int16_t& tag) {
uint8_t b;
if (!readByte(b)) {
return false;
}
uint8_t ctype = b & 0xf;
type = getTType(ctype);
if (type == -1) {
return false;
} else if (type == T_STOP) {
tag = 0;
return true;
}
uint8_t diff = (b & 0xf0) >> 4;
if (diff) {
tag = readTags_.top() + diff;
} else if (!readI16(tag)) {
readTags_.top() = -1;
return false;
}
if (ctype == CT_BOOLEAN_FALSE || ctype == CT_BOOLEAN_TRUE) {
readBool_.exists = true;
readBool_.value = ctype == CT_BOOLEAN_TRUE;
}
readTags_.top() = tag;
return true;
}
TType CompactProtocol::getTType(uint8_t type) {
switch (type) {
case T_STOP:
return T_STOP;
case CT_BOOLEAN_FALSE:
case CT_BOOLEAN_TRUE:
return T_BOOL;
case CT_BYTE:
return T_BYTE;
case CT_I16:
return T_I16;
case CT_I32:
return T_I32;
case CT_I64:
return T_I64;
case CT_DOUBLE:
return T_DOUBLE;
case CT_BINARY:
return T_STRING;
case CT_LIST:
return T_LIST;
case CT_SET:
return T_SET;
case CT_MAP:
return T_MAP;
case CT_STRUCT:
return T_STRUCT;
default:
PyErr_Format(PyExc_TypeError, "don't know what type: %d", type);
return static_cast<TType>(-1);
}
}
}
}
}

View file

@ -0,0 +1,367 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_PY_COMPACT_H
#define THRIFT_PY_COMPACT_H
#include <Python.h>
#include "ext/protocol.h"
#include "ext/endian.h"
#include <stdint.h>
#include <stack>
namespace apache {
namespace thrift {
namespace py {
class CompactProtocol : public ProtocolBase<CompactProtocol> {
public:
CompactProtocol() { readBool_.exists = false; }
virtual ~CompactProtocol() {}
void writeI8(int8_t val) { writeBuffer(reinterpret_cast<char*>(&val), 1); }
void writeI16(int16_t val) { writeVarint(toZigZag(val)); }
int writeI32(int32_t val) { return writeVarint(toZigZag(val)); }
void writeI64(int64_t val) { writeVarint64(toZigZag64(val)); }
void writeDouble(double dub) {
union {
double f;
int64_t t;
} transfer;
transfer.f = htolell(dub);
writeBuffer(reinterpret_cast<char*>(&transfer.t), sizeof(int64_t));
}
void writeBool(int v) { writeByte(static_cast<uint8_t>(v ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE)); }
void writeString(PyObject* value, int32_t len) {
writeVarint(len);
writeBuffer(PyBytes_AS_STRING(value), len);
}
bool writeListBegin(PyObject* value, const SetListTypeArgs& args, int32_t len) {
int ctype = toCompactType(args.element_type);
if (len <= 14) {
writeByte(static_cast<uint8_t>(len << 4 | ctype));
} else {
writeByte(0xf0 | ctype);
writeVarint(len);
}
return true;
}
bool writeMapBegin(PyObject* value, const MapTypeArgs& args, int32_t len) {
if (len == 0) {
writeByte(0);
return true;
}
int ctype = toCompactType(args.ktag) << 4 | toCompactType(args.vtag);
writeVarint(len);
writeByte(ctype);
return true;
}
bool writeStructBegin() {
writeTags_.push(0);
return true;
}
bool writeStructEnd() {
writeTags_.pop();
return true;
}
bool writeField(PyObject* value, const StructItemSpec& spec) {
if (spec.type == T_BOOL) {
doWriteFieldBegin(spec, PyObject_IsTrue(value) ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE);
return true;
} else {
doWriteFieldBegin(spec, toCompactType(spec.type));
return encodeValue(value, spec.type, spec.typeargs);
}
}
void writeFieldStop() { writeByte(0); }
bool readBool(bool& val) {
if (readBool_.exists) {
readBool_.exists = false;
val = readBool_.value;
return true;
}
char* buf;
if (!readBytes(&buf, 1)) {
return false;
}
val = buf[0] == CT_BOOLEAN_TRUE;
return true;
}
bool readI8(int8_t& val) {
char* buf;
if (!readBytes(&buf, 1)) {
return false;
}
val = buf[0];
return true;
}
bool readI16(int16_t& val) {
uint16_t uval;
if (readVarint<uint16_t, 3>(uval)) {
val = fromZigZag<int16_t, uint16_t>(uval);
return true;
}
return false;
}
bool readI32(int32_t& val) {
uint32_t uval;
if (readVarint<uint32_t, 5>(uval)) {
val = fromZigZag<int32_t, uint32_t>(uval);
return true;
}
return false;
}
bool readI64(int64_t& val) {
uint64_t uval;
if (readVarint<uint64_t, 10>(uval)) {
val = fromZigZag<int64_t, uint64_t>(uval);
return true;
}
return false;
}
bool readDouble(double& val) {
union {
int64_t f;
double t;
} transfer;
char* buf;
if (!readBytes(&buf, 8)) {
return false;
}
transfer.f = letohll(*reinterpret_cast<int64_t*>(buf));
val = transfer.t;
return true;
}
int32_t readString(char** buf) {
uint32_t len;
if (!readVarint<uint32_t, 5>(len) || !checkLengthLimit(len, stringLimit())) {
return -1;
}
if (len == 0) {
return 0;
}
if (!readBytes(buf, len)) {
return -1;
}
return len;
}
int32_t readListBegin(TType& etype) {
uint8_t b;
if (!readByte(b)) {
return -1;
}
etype = getTType(b & 0xf);
if (etype == -1) {
return -1;
}
uint32_t len = (b >> 4) & 0xf;
if (len == 15 && !readVarint<uint32_t, 5>(len)) {
return -1;
}
if (!checkLengthLimit(len, containerLimit())) {
return -1;
}
return len;
}
int32_t readMapBegin(TType& ktype, TType& vtype) {
uint32_t len;
if (!readVarint<uint32_t, 5>(len) || !checkLengthLimit(len, containerLimit())) {
return -1;
}
if (len != 0) {
uint8_t kvType;
if (!readByte(kvType)) {
return -1;
}
ktype = getTType(kvType >> 4);
vtype = getTType(kvType & 0xf);
if (ktype == -1 || vtype == -1) {
return -1;
}
}
return len;
}
bool readStructBegin() {
readTags_.push(0);
return true;
}
bool readStructEnd() {
readTags_.pop();
return true;
}
bool readFieldBegin(TType& type, int16_t& tag);
bool skipBool() {
bool val;
return readBool(val);
}
#define SKIPBYTES(n) \
do { \
if (!readBytes(&dummy_buf_, (n))) { \
return false; \
} \
return true; \
} while (0)
bool skipByte() { SKIPBYTES(1); }
bool skipDouble() { SKIPBYTES(8); }
bool skipI16() {
int16_t val;
return readI16(val);
}
bool skipI32() {
int32_t val;
return readI32(val);
}
bool skipI64() {
int64_t val;
return readI64(val);
}
bool skipString() {
uint32_t len;
if (!readVarint<uint32_t, 5>(len)) {
return false;
}
SKIPBYTES(len);
}
#undef SKIPBYTES
private:
enum Types {
CT_STOP = 0x00,
CT_BOOLEAN_TRUE = 0x01,
CT_BOOLEAN_FALSE = 0x02,
CT_BYTE = 0x03,
CT_I16 = 0x04,
CT_I32 = 0x05,
CT_I64 = 0x06,
CT_DOUBLE = 0x07,
CT_BINARY = 0x08,
CT_LIST = 0x09,
CT_SET = 0x0A,
CT_MAP = 0x0B,
CT_STRUCT = 0x0C
};
static const uint8_t TTypeToCType[];
TType getTType(uint8_t type);
int toCompactType(TType type) {
int i = static_cast<int>(type);
return i < 16 ? TTypeToCType[i] : -1;
}
uint32_t toZigZag(int32_t val) { return (val >> 31) ^ (val << 1); }
uint64_t toZigZag64(int64_t val) { return (val >> 63) ^ (val << 1); }
int writeVarint(uint32_t val) {
int cnt = 1;
while (val & ~0x7fU) {
writeByte(static_cast<char>((val & 0x7fU) | 0x80U));
val >>= 7;
++cnt;
}
writeByte(static_cast<char>(val));
return cnt;
}
int writeVarint64(uint64_t val) {
int cnt = 1;
while (val & ~0x7fULL) {
writeByte(static_cast<char>((val & 0x7fULL) | 0x80ULL));
val >>= 7;
++cnt;
}
writeByte(static_cast<char>(val));
return cnt;
}
template <typename T, int Max>
bool readVarint(T& result) {
uint8_t b;
T val = 0;
int shift = 0;
for (int i = 0; i < Max; ++i) {
if (!readByte(b)) {
return false;
}
if (b & 0x80) {
val |= static_cast<T>(b & 0x7f) << shift;
} else {
val |= static_cast<T>(b) << shift;
result = val;
return true;
}
shift += 7;
}
PyErr_Format(PyExc_OverflowError, "varint exceeded %d bytes", Max);
return false;
}
template <typename S, typename U>
S fromZigZag(U val) {
return (val >> 1) ^ static_cast<U>(-static_cast<S>(val & 1));
}
void doWriteFieldBegin(const StructItemSpec& spec, int ctype) {
int diff = spec.tag - writeTags_.top();
if (diff > 0 && diff <= 15) {
writeByte(static_cast<uint8_t>(diff << 4 | ctype));
} else {
writeByte(static_cast<uint8_t>(ctype));
writeI16(spec.tag);
}
writeTags_.top() = spec.tag;
}
std::stack<int> writeTags_;
std::stack<int> readTags_;
struct {
bool exists;
bool value;
} readBool_;
char* dummy_buf_;
};
}
}
}
#endif // THRIFT_PY_COMPACT_H

View file

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_PY_ENDIAN_H
#define THRIFT_PY_ENDIAN_H
#include <Python.h>
#ifndef _WIN32
#include <netinet/in.h>
#else
#include <WinSock2.h>
#pragma comment(lib, "ws2_32.lib")
#define BIG_ENDIAN (4321)
#define LITTLE_ENDIAN (1234)
#define BYTE_ORDER LITTLE_ENDIAN
#define inline __inline
#endif
/* Fix endianness issues on Solaris */
#if defined(__SVR4) && defined(__sun)
#if defined(__i386) && !defined(__i386__)
#define __i386__
#endif
#ifndef BIG_ENDIAN
#define BIG_ENDIAN (4321)
#endif
#ifndef LITTLE_ENDIAN
#define LITTLE_ENDIAN (1234)
#endif
/* I386 is LE, even on Solaris */
#if !defined(BYTE_ORDER) && defined(__i386__)
#define BYTE_ORDER LITTLE_ENDIAN
#endif
#endif
#ifndef __BYTE_ORDER
#if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN)
#define __BYTE_ORDER BYTE_ORDER
#define __LITTLE_ENDIAN LITTLE_ENDIAN
#define __BIG_ENDIAN BIG_ENDIAN
#else
#error "Cannot determine endianness"
#endif
#endif
// Same comment as the enum. Sorry.
#if __BYTE_ORDER == __BIG_ENDIAN
#define ntohll(n) (n)
#define htonll(n) (n)
#if defined(__GNUC__) && defined(__GLIBC__)
#include <byteswap.h>
#define letohll(n) bswap_64(n)
#define htolell(n) bswap_64(n)
#else /* GNUC & GLIBC */
#define letohll(n) ((((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32))
#define htolell(n) ((((unsigned long long)htonl(n)) << 32) + htonl(n >> 32))
#endif
#elif __BYTE_ORDER == __LITTLE_ENDIAN
#if defined(__GNUC__) && defined(__GLIBC__)
#include <byteswap.h>
#define ntohll(n) bswap_64(n)
#define htonll(n) bswap_64(n)
#else /* GNUC & GLIBC */
#define ntohll(n) ((((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32))
#define htonll(n) ((((unsigned long long)htonl(n)) << 32) + htonl(n >> 32))
#endif /* GNUC & GLIBC */
#define letohll(n) (n)
#define htolell(n) (n)
#else /* __BYTE_ORDER */
#error "Can't define htonll or ntohll!"
#endif
#endif // THRIFT_PY_ENDIAN_H

View file

@ -0,0 +1,208 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <Python.h>
#include "types.h"
#include "binary.h"
#include "compact.h"
#include <limits>
#include <stdint.h>
// TODO(dreiss): defval appears to be unused. Look into removing it.
// TODO(dreiss): Make parse_spec_args recursive, and cache the output
// permanently in the object. (Malloc and orphan.)
// TODO(dreiss): Why do we need cStringIO for reading, why not just char*?
// Can cStringIO let us work with a BufferedTransport?
// TODO(dreiss): Don't ignore the rv from cwrite (maybe).
// Doing a benchmark shows that interning actually makes a difference, amazingly.
/** Pointer to interned string to speed up attribute lookup. */
PyObject* INTERN_STRING(TFrozenDict);
PyObject* INTERN_STRING(cstringio_buf);
PyObject* INTERN_STRING(cstringio_refill);
static PyObject* INTERN_STRING(string_length_limit);
static PyObject* INTERN_STRING(container_length_limit);
static PyObject* INTERN_STRING(trans);
namespace apache {
namespace thrift {
namespace py {
template <typename T>
static PyObject* encode_impl(PyObject* args) {
if (!args)
return NULL;
PyObject* enc_obj = NULL;
PyObject* type_args = NULL;
if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) {
return NULL;
}
if (!enc_obj || !type_args) {
return NULL;
}
T protocol;
if (!protocol.prepareEncodeBuffer() || !protocol.encodeValue(enc_obj, T_STRUCT, type_args)) {
return NULL;
}
return protocol.getEncodedValue();
}
static inline long as_long_then_delete(PyObject* value, long default_value) {
ScopedPyObject scope(value);
long v = PyInt_AsLong(value);
if (INT_CONV_ERROR_OCCURRED(v)) {
PyErr_Clear();
return default_value;
}
return v;
}
template <typename T>
static PyObject* decode_impl(PyObject* args) {
PyObject* output_obj = NULL;
PyObject* oprot = NULL;
PyObject* typeargs = NULL;
if (!PyArg_ParseTuple(args, "OOO", &output_obj, &oprot, &typeargs)) {
return NULL;
}
T protocol;
#ifdef _MSC_VER
// workaround strange VC++ 2015 bug where #else path does not compile
int32_t default_limit = INT32_MAX;
#else
int32_t default_limit = std::numeric_limits<int32_t>::max();
#endif
protocol.setStringLengthLimit(
as_long_then_delete(PyObject_GetAttr(oprot, INTERN_STRING(string_length_limit)),
default_limit));
protocol.setContainerLengthLimit(
as_long_then_delete(PyObject_GetAttr(oprot, INTERN_STRING(container_length_limit)),
default_limit));
ScopedPyObject transport(PyObject_GetAttr(oprot, INTERN_STRING(trans)));
if (!transport) {
return NULL;
}
StructTypeArgs parsedargs;
if (!parse_struct_args(&parsedargs, typeargs)) {
return NULL;
}
if (!protocol.prepareDecodeBufferFromTransport(transport.get())) {
return NULL;
}
return protocol.readStruct(output_obj, parsedargs.klass, parsedargs.spec);
}
}
}
}
using namespace apache::thrift::py;
/* -- PYTHON MODULE SETUP STUFF --- */
extern "C" {
static PyObject* encode_binary(PyObject*, PyObject* args) {
return encode_impl<BinaryProtocol>(args);
}
static PyObject* decode_binary(PyObject*, PyObject* args) {
return decode_impl<BinaryProtocol>(args);
}
static PyObject* encode_compact(PyObject*, PyObject* args) {
return encode_impl<CompactProtocol>(args);
}
static PyObject* decode_compact(PyObject*, PyObject* args) {
return decode_impl<CompactProtocol>(args);
}
static PyMethodDef ThriftFastBinaryMethods[] = {
{"encode_binary", encode_binary, METH_VARARGS, ""},
{"decode_binary", decode_binary, METH_VARARGS, ""},
{"encode_compact", encode_compact, METH_VARARGS, ""},
{"decode_compact", decode_compact, METH_VARARGS, ""},
{NULL, NULL, 0, NULL} /* Sentinel */
};
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef ThriftFastBinaryDef = {PyModuleDef_HEAD_INIT,
"thrift.protocol.fastbinary",
NULL,
0,
ThriftFastBinaryMethods,
NULL,
NULL,
NULL,
NULL};
#define INITERROR return NULL;
PyObject* PyInit_fastbinary() {
#else
#define INITERROR return;
void initfastbinary() {
PycString_IMPORT;
if (PycStringIO == NULL)
INITERROR
#endif
#define INIT_INTERN_STRING(value) \
do { \
INTERN_STRING(value) = PyString_InternFromString(#value); \
if (!INTERN_STRING(value)) \
INITERROR \
} while (0)
INIT_INTERN_STRING(TFrozenDict);
INIT_INTERN_STRING(cstringio_buf);
INIT_INTERN_STRING(cstringio_refill);
INIT_INTERN_STRING(string_length_limit);
INIT_INTERN_STRING(container_length_limit);
INIT_INTERN_STRING(trans);
#undef INIT_INTERN_STRING
PyObject* module =
#if PY_MAJOR_VERSION >= 3
PyModule_Create(&ThriftFastBinaryDef);
#else
Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods);
#endif
if (module == NULL)
INITERROR;
#if PY_MAJOR_VERSION >= 3
return module;
#endif
}
}

View file

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_PY_PROTOCOL_H
#define THRIFT_PY_PROTOCOL_H
#include "ext/types.h"
#include <limits>
#include <stdint.h>
namespace apache {
namespace thrift {
namespace py {
template <typename Impl>
class ProtocolBase {
public:
ProtocolBase()
: stringLimit_(std::numeric_limits<int32_t>::max()),
containerLimit_(std::numeric_limits<int32_t>::max()),
output_(NULL) {}
inline virtual ~ProtocolBase();
bool prepareDecodeBufferFromTransport(PyObject* trans);
PyObject* readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq);
bool prepareEncodeBuffer();
bool encodeValue(PyObject* value, TType type, PyObject* typeargs);
PyObject* getEncodedValue();
long stringLimit() const { return stringLimit_; }
void setStringLengthLimit(long limit) { stringLimit_ = limit; }
long containerLimit() const { return containerLimit_; }
void setContainerLengthLimit(long limit) { containerLimit_ = limit; }
protected:
bool readBytes(char** output, int len);
bool readByte(uint8_t& val) {
char* buf;
if (!readBytes(&buf, 1)) {
return false;
}
val = static_cast<uint8_t>(buf[0]);
return true;
}
bool writeBuffer(char* data, size_t len);
void writeByte(uint8_t val) { writeBuffer(reinterpret_cast<char*>(&val), 1); }
PyObject* decodeValue(TType type, PyObject* typeargs);
bool skip(TType type);
inline bool checkType(TType got, TType expected);
inline bool checkLengthLimit(int32_t len, long limit);
inline bool isUtf8(PyObject* typeargs);
private:
Impl* impl() { return static_cast<Impl*>(this); }
long stringLimit_;
long containerLimit_;
EncodeBuffer* output_;
DecodeBuffer input_;
};
}
}
}
#include "ext/protocol.tcc"
#endif // THRIFT_PY_PROTOCOL_H

View file

@ -0,0 +1,913 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_PY_PROTOCOL_TCC
#define THRIFT_PY_PROTOCOL_TCC
#include <iterator>
#define CHECK_RANGE(v, min, max) (((v) <= (max)) && ((v) >= (min)))
#define INIT_OUTBUF_SIZE 128
#if PY_MAJOR_VERSION < 3
#include <cStringIO.h>
#else
#include <algorithm>
#endif
namespace apache {
namespace thrift {
namespace py {
#if PY_MAJOR_VERSION < 3
namespace detail {
inline bool input_check(PyObject* input) {
return PycStringIO_InputCheck(input);
}
inline EncodeBuffer* new_encode_buffer(size_t size) {
if (!PycStringIO) {
PycString_IMPORT;
}
if (!PycStringIO) {
return NULL;
}
return PycStringIO->NewOutput(size);
}
inline int read_buffer(PyObject* buf, char** output, int len) {
if (!PycStringIO) {
PycString_IMPORT;
}
if (!PycStringIO) {
PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO");
return -1;
}
return PycStringIO->cread(buf, output, len);
}
}
template <typename Impl>
inline ProtocolBase<Impl>::~ProtocolBase() {
if (output_) {
Py_CLEAR(output_);
}
}
template <typename Impl>
inline bool ProtocolBase<Impl>::isUtf8(PyObject* typeargs) {
return PyString_Check(typeargs) && !strncmp(PyString_AS_STRING(typeargs), "UTF8", 4);
}
template <typename Impl>
PyObject* ProtocolBase<Impl>::getEncodedValue() {
if (!PycStringIO) {
PycString_IMPORT;
}
if (!PycStringIO) {
return NULL;
}
return PycStringIO->cgetvalue(output_);
}
template <typename Impl>
inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
if (!PycStringIO) {
PycString_IMPORT;
}
if (!PycStringIO) {
PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO");
return false;
}
int len = PycStringIO->cwrite(output_, data, size);
if (len < 0) {
PyErr_SetString(PyExc_IOError, "failed to write to cStringIO object");
return false;
}
if (len != size) {
PyErr_Format(PyExc_EOFError, "write length mismatch: expected %lu got %d", size, len);
return false;
}
return true;
}
#else
namespace detail {
inline bool input_check(PyObject* input) {
// TODO: Check for BytesIO type
return true;
}
inline EncodeBuffer* new_encode_buffer(size_t size) {
EncodeBuffer* buffer = new EncodeBuffer;
buffer->buf.reserve(size);
buffer->pos = 0;
return buffer;
}
struct bytesio {
PyObject_HEAD
#if PY_MINOR_VERSION < 5
char* buf;
#else
PyObject* buf;
#endif
Py_ssize_t pos;
Py_ssize_t string_size;
};
inline int read_buffer(PyObject* buf, char** output, int len) {
bytesio* buf2 = reinterpret_cast<bytesio*>(buf);
#if PY_MINOR_VERSION < 5
*output = buf2->buf + buf2->pos;
#else
*output = PyBytes_AS_STRING(buf2->buf) + buf2->pos;
#endif
Py_ssize_t pos0 = buf2->pos;
buf2->pos = std::min(buf2->pos + static_cast<Py_ssize_t>(len), buf2->string_size);
return static_cast<int>(buf2->pos - pos0);
}
}
template <typename Impl>
inline ProtocolBase<Impl>::~ProtocolBase() {
if (output_) {
delete output_;
}
}
template <typename Impl>
inline bool ProtocolBase<Impl>::isUtf8(PyObject* typeargs) {
// while condition for py2 is "arg == 'UTF8'", it should be "arg != 'BINARY'" for py3.
// HACK: check the length and don't bother reading the value
return !PyUnicode_Check(typeargs) || PyUnicode_GET_LENGTH(typeargs) != 6;
}
template <typename Impl>
PyObject* ProtocolBase<Impl>::getEncodedValue() {
return PyBytes_FromStringAndSize(output_->buf.data(), output_->buf.size());
}
template <typename Impl>
inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
size_t need = size + output_->pos;
if (output_->buf.capacity() < need) {
try {
output_->buf.reserve(need);
} catch (std::bad_alloc& ex) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer");
return false;
}
}
std::copy(data, data + size, std::back_inserter(output_->buf));
return true;
}
#endif
namespace detail {
#define DECLARE_OP_SCOPE(name, op) \
template <typename Impl> \
struct name##Scope { \
Impl* impl; \
bool valid; \
name##Scope(Impl* thiz) : impl(thiz), valid(impl->op##Begin()) {} \
~name##Scope() { \
if (valid) \
impl->op##End(); \
} \
operator bool() { return valid; } \
}; \
template <typename Impl, template <typename> class T> \
name##Scope<Impl> op##Scope(T<Impl>* thiz) { \
return name##Scope<Impl>(static_cast<Impl*>(thiz)); \
}
DECLARE_OP_SCOPE(WriteStruct, writeStruct)
DECLARE_OP_SCOPE(ReadStruct, readStruct)
#undef DECLARE_OP_SCOPE
inline bool check_ssize_t_32(Py_ssize_t len) {
// error from getting the int
if (INT_CONV_ERROR_OCCURRED(len)) {
return false;
}
if (!CHECK_RANGE(len, 0, std::numeric_limits<int32_t>::max())) {
PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX");
return false;
}
return true;
}
}
template <typename T>
bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) {
long val = PyInt_AsLong(o);
if (INT_CONV_ERROR_OCCURRED(val)) {
return false;
}
if (!CHECK_RANGE(val, min, max)) {
PyErr_SetString(PyExc_OverflowError, "int out of range");
return false;
}
*ret = static_cast<T>(val);
return true;
}
template <typename Impl>
inline bool ProtocolBase<Impl>::checkType(TType got, TType expected) {
if (expected != got) {
PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
return false;
}
return true;
}
template <typename Impl>
bool ProtocolBase<Impl>::checkLengthLimit(int32_t len, long limit) {
if (len < 0) {
PyErr_Format(PyExc_OverflowError, "negative length: %ld", limit);
return false;
}
if (len > limit) {
PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %ld", limit);
return false;
}
return true;
}
template <typename Impl>
bool ProtocolBase<Impl>::readBytes(char** output, int len) {
if (len < 0) {
PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len);
return false;
}
// TODO(dreiss): Don't fear the malloc. Think about taking a copy of
// the partial read instead of forcing the transport
// to prepend it to its buffer.
int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len);
if (rlen == len) {
return true;
} else if (rlen == -1) {
return false;
} else {
// using building functions as this is a rare codepath
ScopedPyObject newiobuf(PyObject_CallFunction(input_.refill_callable.get(), refill_signature,
*output, rlen, len, NULL));
if (!newiobuf) {
return false;
}
// must do this *AFTER* the call so that we don't deref the io buffer
input_.stringiobuf.reset(newiobuf.release());
rlen = detail::read_buffer(input_.stringiobuf.get(), output, len);
if (rlen == len) {
return true;
} else if (rlen == -1) {
return false;
} else {
// TODO(dreiss): This could be a valid code path for big binary blobs.
PyErr_SetString(PyExc_TypeError, "refill claimed to have refilled the buffer, but didn't!!");
return false;
}
}
}
template <typename Impl>
bool ProtocolBase<Impl>::prepareDecodeBufferFromTransport(PyObject* trans) {
if (input_.stringiobuf) {
PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized");
return false;
}
ScopedPyObject stringiobuf(PyObject_GetAttr(trans, INTERN_STRING(cstringio_buf)));
if (!stringiobuf) {
return false;
}
if (!detail::input_check(stringiobuf.get())) {
PyErr_SetString(PyExc_TypeError, "expecting stringio input_");
return false;
}
ScopedPyObject refill_callable(PyObject_GetAttr(trans, INTERN_STRING(cstringio_refill)));
if (!refill_callable) {
return false;
}
if (!PyCallable_Check(refill_callable.get())) {
PyErr_SetString(PyExc_TypeError, "expecting callable");
return false;
}
input_.stringiobuf.swap(stringiobuf);
input_.refill_callable.swap(refill_callable);
return true;
}
template <typename Impl>
bool ProtocolBase<Impl>::prepareEncodeBuffer() {
output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE);
return output_ != NULL;
}
template <typename Impl>
bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* typeargs) {
/*
* Refcounting Strategy:
*
* We assume that elements of the thrift_spec tuple are not going to be
* mutated, so we don't ref count those at all. Other than that, we try to
* keep a reference to all the user-created objects while we work with them.
* encodeValue assumes that a reference is already held. The *caller* is
* responsible for handling references
*/
switch (type) {
case T_BOOL: {
int v = PyObject_IsTrue(value);
if (v == -1) {
return false;
}
impl()->writeBool(v);
return true;
}
case T_I08: {
int8_t val;
if (!parse_pyint(value, &val, std::numeric_limits<int8_t>::min(),
std::numeric_limits<int8_t>::max())) {
return false;
}
impl()->writeI8(val);
return true;
}
case T_I16: {
int16_t val;
if (!parse_pyint(value, &val, std::numeric_limits<int16_t>::min(),
std::numeric_limits<int16_t>::max())) {
return false;
}
impl()->writeI16(val);
return true;
}
case T_I32: {
int32_t val;
if (!parse_pyint(value, &val, std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::max())) {
return false;
}
impl()->writeI32(val);
return true;
}
case T_I64: {
int64_t nval = PyLong_AsLongLong(value);
if (INT_CONV_ERROR_OCCURRED(nval)) {
return false;
}
if (!CHECK_RANGE(nval, std::numeric_limits<int64_t>::min(),
std::numeric_limits<int64_t>::max())) {
PyErr_SetString(PyExc_OverflowError, "int out of range");
return false;
}
impl()->writeI64(nval);
return true;
}
case T_DOUBLE: {
double nval = PyFloat_AsDouble(value);
if (nval == -1.0 && PyErr_Occurred()) {
return false;
}
impl()->writeDouble(nval);
return true;
}
case T_STRING: {
ScopedPyObject nval;
if (PyUnicode_Check(value)) {
nval.reset(PyUnicode_AsUTF8String(value));
if (!nval) {
return false;
}
} else {
Py_INCREF(value);
nval.reset(value);
}
Py_ssize_t len = PyBytes_Size(nval.get());
if (!detail::check_ssize_t_32(len)) {
return false;
}
impl()->writeString(nval.get(), static_cast<int32_t>(len));
return true;
}
case T_LIST:
case T_SET: {
SetListTypeArgs parsedargs;
if (!parse_set_list_args(&parsedargs, typeargs)) {
return false;
}
Py_ssize_t len = PyObject_Length(value);
if (!detail::check_ssize_t_32(len)) {
return false;
}
if (!impl()->writeListBegin(value, parsedargs, static_cast<int32_t>(len)) || PyErr_Occurred()) {
return false;
}
ScopedPyObject iterator(PyObject_GetIter(value));
if (!iterator) {
return false;
}
while (PyObject* rawItem = PyIter_Next(iterator.get())) {
ScopedPyObject item(rawItem);
if (!encodeValue(item.get(), parsedargs.element_type, parsedargs.typeargs)) {
return false;
}
}
return true;
}
case T_MAP: {
Py_ssize_t len = PyDict_Size(value);
if (!detail::check_ssize_t_32(len)) {
return false;
}
MapTypeArgs parsedargs;
if (!parse_map_args(&parsedargs, typeargs)) {
return false;
}
if (!impl()->writeMapBegin(value, parsedargs, static_cast<int32_t>(len)) || PyErr_Occurred()) {
return false;
}
Py_ssize_t pos = 0;
PyObject* k = NULL;
PyObject* v = NULL;
// TODO(bmaurer): should support any mapping, not just dicts
while (PyDict_Next(value, &pos, &k, &v)) {
if (!encodeValue(k, parsedargs.ktag, parsedargs.ktypeargs)
|| !encodeValue(v, parsedargs.vtag, parsedargs.vtypeargs)) {
return false;
}
}
return true;
}
case T_STRUCT: {
StructTypeArgs parsedargs;
if (!parse_struct_args(&parsedargs, typeargs)) {
return false;
}
Py_ssize_t nspec = PyTuple_Size(parsedargs.spec);
if (nspec == -1) {
PyErr_SetString(PyExc_TypeError, "spec is not a tuple");
return false;
}
detail::WriteStructScope<Impl> scope = detail::writeStructScope(this);
if (!scope) {
return false;
}
for (Py_ssize_t i = 0; i < nspec; i++) {
PyObject* spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i);
if (spec_tuple == Py_None) {
continue;
}
StructItemSpec parsedspec;
if (!parse_struct_item_spec(&parsedspec, spec_tuple)) {
return false;
}
ScopedPyObject instval(PyObject_GetAttr(value, parsedspec.attrname));
if (!instval) {
return false;
}
if (instval.get() == Py_None) {
continue;
}
bool res = impl()->writeField(instval.get(), parsedspec);
if (!res) {
return false;
}
}
impl()->writeFieldStop();
return true;
}
case T_STOP:
case T_VOID:
case T_UTF16:
case T_UTF8:
case T_U64:
default:
PyErr_Format(PyExc_TypeError, "Unexpected TType for encodeValue: %d", type);
return false;
}
return true;
}
template <typename Impl>
bool ProtocolBase<Impl>::skip(TType type) {
switch (type) {
case T_BOOL:
return impl()->skipBool();
case T_I08:
return impl()->skipByte();
case T_I16:
return impl()->skipI16();
case T_I32:
return impl()->skipI32();
case T_I64:
return impl()->skipI64();
case T_DOUBLE:
return impl()->skipDouble();
case T_STRING: {
return impl()->skipString();
}
case T_LIST:
case T_SET: {
TType etype = T_STOP;
int32_t len = impl()->readListBegin(etype);
if (len < 0) {
return false;
}
for (int32_t i = 0; i < len; i++) {
if (!skip(etype)) {
return false;
}
}
return true;
}
case T_MAP: {
TType ktype = T_STOP;
TType vtype = T_STOP;
int32_t len = impl()->readMapBegin(ktype, vtype);
if (len < 0) {
return false;
}
for (int32_t i = 0; i < len; i++) {
if (!skip(ktype) || !skip(vtype)) {
return false;
}
}
return true;
}
case T_STRUCT: {
detail::ReadStructScope<Impl> scope = detail::readStructScope(this);
if (!scope) {
return false;
}
while (true) {
TType type = T_STOP;
int16_t tag;
if (!impl()->readFieldBegin(type, tag)) {
return false;
}
if (type == T_STOP) {
return true;
}
if (!skip(type)) {
return false;
}
}
return true;
}
case T_STOP:
case T_VOID:
case T_UTF16:
case T_UTF8:
case T_U64:
default:
PyErr_Format(PyExc_TypeError, "Unexpected TType for skip: %d", type);
return false;
}
return true;
}
// Returns a new reference.
template <typename Impl>
PyObject* ProtocolBase<Impl>::decodeValue(TType type, PyObject* typeargs) {
switch (type) {
case T_BOOL: {
bool v = 0;
if (!impl()->readBool(v)) {
return NULL;
}
if (v) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
case T_I08: {
int8_t v = 0;
if (!impl()->readI8(v)) {
return NULL;
}
return PyInt_FromLong(v);
}
case T_I16: {
int16_t v = 0;
if (!impl()->readI16(v)) {
return NULL;
}
return PyInt_FromLong(v);
}
case T_I32: {
int32_t v = 0;
if (!impl()->readI32(v)) {
return NULL;
}
return PyInt_FromLong(v);
}
case T_I64: {
int64_t v = 0;
if (!impl()->readI64(v)) {
return NULL;
}
// TODO(dreiss): Find out if we can take this fastpath always when
// sizeof(long) == sizeof(long long).
if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) {
return PyInt_FromLong((long)v);
}
return PyLong_FromLongLong(v);
}
case T_DOUBLE: {
double v = 0.0;
if (!impl()->readDouble(v)) {
return NULL;
}
return PyFloat_FromDouble(v);
}
case T_STRING: {
char* buf = NULL;
int len = impl()->readString(&buf);
if (len < 0) {
return NULL;
}
if (isUtf8(typeargs)) {
return PyUnicode_DecodeUTF8(buf, len, 0);
} else {
return PyBytes_FromStringAndSize(buf, len);
}
}
case T_LIST:
case T_SET: {
SetListTypeArgs parsedargs;
if (!parse_set_list_args(&parsedargs, typeargs)) {
return NULL;
}
TType etype = T_STOP;
int32_t len = impl()->readListBegin(etype);
if (len < 0) {
return NULL;
}
if (len > 0 && !checkType(etype, parsedargs.element_type)) {
return NULL;
}
bool use_tuple = type == T_LIST && parsedargs.immutable;
ScopedPyObject ret(use_tuple ? PyTuple_New(len) : PyList_New(len));
if (!ret) {
return NULL;
}
for (int i = 0; i < len; i++) {
PyObject* item = decodeValue(etype, parsedargs.typeargs);
if (!item) {
return NULL;
}
if (use_tuple) {
PyTuple_SET_ITEM(ret.get(), i, item);
} else {
PyList_SET_ITEM(ret.get(), i, item);
}
}
// TODO(dreiss): Consider biting the bullet and making two separate cases
// for list and set, avoiding this post facto conversion.
if (type == T_SET) {
PyObject* setret;
setret = parsedargs.immutable ? PyFrozenSet_New(ret.get()) : PySet_New(ret.get());
return setret;
}
return ret.release();
}
case T_MAP: {
MapTypeArgs parsedargs;
if (!parse_map_args(&parsedargs, typeargs)) {
return NULL;
}
TType ktype = T_STOP;
TType vtype = T_STOP;
uint32_t len = impl()->readMapBegin(ktype, vtype);
if (len > 0 && (!checkType(ktype, parsedargs.ktag) || !checkType(vtype, parsedargs.vtag))) {
return NULL;
}
ScopedPyObject ret(PyDict_New());
if (!ret) {
return NULL;
}
for (uint32_t i = 0; i < len; i++) {
ScopedPyObject k(decodeValue(ktype, parsedargs.ktypeargs));
if (!k) {
return NULL;
}
ScopedPyObject v(decodeValue(vtype, parsedargs.vtypeargs));
if (!v) {
return NULL;
}
if (PyDict_SetItem(ret.get(), k.get(), v.get()) == -1) {
return NULL;
}
}
if (parsedargs.immutable) {
if (!ThriftModule) {
ThriftModule = PyImport_ImportModule("thrift.Thrift");
}
if (!ThriftModule) {
return NULL;
}
ScopedPyObject cls(PyObject_GetAttr(ThriftModule, INTERN_STRING(TFrozenDict)));
if (!cls) {
return NULL;
}
ScopedPyObject arg(PyTuple_New(1));
PyTuple_SET_ITEM(arg.get(), 0, ret.release());
ret.reset(PyObject_CallObject(cls.get(), arg.get()));
}
return ret.release();
}
case T_STRUCT: {
StructTypeArgs parsedargs;
if (!parse_struct_args(&parsedargs, typeargs)) {
return NULL;
}
return readStruct(Py_None, parsedargs.klass, parsedargs.spec);
}
case T_STOP:
case T_VOID:
case T_UTF16:
case T_UTF8:
case T_U64:
default:
PyErr_Format(PyExc_TypeError, "Unexpected TType for decodeValue: %d", type);
return NULL;
}
}
template <typename Impl>
PyObject* ProtocolBase<Impl>::readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq) {
int spec_seq_len = PyTuple_Size(spec_seq);
bool immutable = output == Py_None;
ScopedPyObject kwargs;
if (spec_seq_len == -1) {
return NULL;
}
if (immutable) {
kwargs.reset(PyDict_New());
if (!kwargs) {
PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage");
return NULL;
}
}
detail::ReadStructScope<Impl> scope = detail::readStructScope(this);
if (!scope) {
return NULL;
}
while (true) {
TType type = T_STOP;
int16_t tag;
if (!impl()->readFieldBegin(type, tag)) {
return NULL;
}
if (type == T_STOP) {
break;
}
if (tag < 0 || tag >= spec_seq_len) {
if (!skip(type)) {
PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field");
return NULL;
}
continue;
}
PyObject* item_spec = PyTuple_GET_ITEM(spec_seq, tag);
if (item_spec == Py_None) {
if (!skip(type)) {
PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field");
return NULL;
}
continue;
}
StructItemSpec parsedspec;
if (!parse_struct_item_spec(&parsedspec, item_spec)) {
return NULL;
}
if (parsedspec.type != type) {
if (!skip(type)) {
PyErr_Format(PyExc_TypeError, "struct field had wrong type: expected %d but got %d",
parsedspec.type, type);
return NULL;
}
continue;
}
ScopedPyObject fieldval(decodeValue(parsedspec.type, parsedspec.typeargs));
if (!fieldval) {
return NULL;
}
if ((immutable && PyDict_SetItem(kwargs.get(), parsedspec.attrname, fieldval.get()) == -1)
|| (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval.get()) == -1)) {
return NULL;
}
}
if (immutable) {
ScopedPyObject args(PyTuple_New(0));
if (!args) {
PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage");
return NULL;
}
return PyObject_Call(klass, args.get(), kwargs.get());
}
Py_INCREF(output);
return output;
}
}
}
}
#endif // THRIFT_PY_PROTOCOL_H

View file

@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "ext/types.h"
#include "ext/protocol.h"
namespace apache {
namespace thrift {
namespace py {
PyObject* ThriftModule = NULL;
#if PY_MAJOR_VERSION < 3
char refill_signature[] = {'s', '#', 'i'};
#else
const char* refill_signature = "y#i";
#endif
bool parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) {
// i'd like to use ParseArgs here, but it seems to be a bottleneck.
if (PyTuple_Size(spec_tuple) != 5) {
PyErr_Format(PyExc_TypeError, "expecting 5 arguments for spec tuple but got %d",
static_cast<int>(PyTuple_Size(spec_tuple)));
return false;
}
dest->tag = static_cast<TType>(PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)));
if (INT_CONV_ERROR_OCCURRED(dest->tag)) {
return false;
}
dest->type = static_cast<TType>(PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)));
if (INT_CONV_ERROR_OCCURRED(dest->type)) {
return false;
}
dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2);
dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3);
dest->defval = PyTuple_GET_ITEM(spec_tuple, 4);
return true;
}
bool parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
if (PyTuple_Size(typeargs) != 3) {
PyErr_SetString(PyExc_TypeError, "expecting tuple of size 3 for list/set type args");
return false;
}
dest->element_type = static_cast<TType>(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)));
if (INT_CONV_ERROR_OCCURRED(dest->element_type)) {
return false;
}
dest->typeargs = PyTuple_GET_ITEM(typeargs, 1);
dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 2);
return true;
}
bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
if (PyTuple_Size(typeargs) != 5) {
PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for typeargs to map");
return false;
}
dest->ktag = static_cast<TType>(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)));
if (INT_CONV_ERROR_OCCURRED(dest->ktag)) {
return false;
}
dest->vtag = static_cast<TType>(PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)));
if (INT_CONV_ERROR_OCCURRED(dest->vtag)) {
return false;
}
dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1);
dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3);
dest->immutable = Py_True == PyTuple_GET_ITEM(typeargs, 4);
return true;
}
bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) {
if (PyTuple_Size(typeargs) != 2) {
PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args");
return false;
}
dest->klass = PyTuple_GET_ITEM(typeargs, 0);
dest->spec = PyTuple_GET_ITEM(typeargs, 1);
return true;
}
}
}
}

191
vendor/git.apache.org/thrift.git/lib/py/src/ext/types.h generated vendored Normal file
View file

@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef THRIFT_PY_TYPES_H
#define THRIFT_PY_TYPES_H
#include <Python.h>
#ifdef _MSC_VER
#define __STDC_LIMIT_MACROS
#endif
#include <stdint.h>
#if PY_MAJOR_VERSION >= 3
#include <vector>
// TODO: better macros
#define PyInt_AsLong(v) PyLong_AsLong(v)
#define PyInt_FromLong(v) PyLong_FromLong(v)
#define PyString_InternFromString(v) PyUnicode_InternFromString(v)
#endif
#define INTERN_STRING(value) _intern_##value
#define INT_CONV_ERROR_OCCURRED(v) (((v) == -1) && PyErr_Occurred())
extern "C" {
extern PyObject* INTERN_STRING(TFrozenDict);
extern PyObject* INTERN_STRING(cstringio_buf);
extern PyObject* INTERN_STRING(cstringio_refill);
}
namespace apache {
namespace thrift {
namespace py {
extern PyObject* ThriftModule;
// Stolen out of TProtocol.h.
// It would be a huge pain to have both get this from one place.
enum TType {
T_INVALID = -1,
T_STOP = 0,
T_VOID = 1,
T_BOOL = 2,
T_BYTE = 3,
T_I08 = 3,
T_I16 = 6,
T_I32 = 8,
T_U64 = 9,
T_I64 = 10,
T_DOUBLE = 4,
T_STRING = 11,
T_UTF7 = 11,
T_STRUCT = 12,
T_MAP = 13,
T_SET = 14,
T_LIST = 15,
T_UTF8 = 16,
T_UTF16 = 17
};
// replace with unique_ptr when we're OK with C++11
class ScopedPyObject {
public:
ScopedPyObject() : obj_(NULL) {}
explicit ScopedPyObject(PyObject* py_object) : obj_(py_object) {}
~ScopedPyObject() {
if (obj_)
Py_DECREF(obj_);
}
PyObject* get() throw() { return obj_; }
operator bool() { return obj_; }
void reset(PyObject* py_object) throw() {
if (obj_)
Py_DECREF(obj_);
obj_ = py_object;
}
PyObject* release() throw() {
PyObject* tmp = obj_;
obj_ = NULL;
return tmp;
}
void swap(ScopedPyObject& other) throw() {
ScopedPyObject tmp(other.release());
other.reset(release());
reset(tmp.release());
}
private:
ScopedPyObject(const ScopedPyObject&) {}
ScopedPyObject& operator=(const ScopedPyObject&) { return *this; }
PyObject* obj_;
};
/**
* A cache of the two key attributes of a CReadableTransport,
* so we don't have to keep calling PyObject_GetAttr.
*/
struct DecodeBuffer {
ScopedPyObject stringiobuf;
ScopedPyObject refill_callable;
};
#if PY_MAJOR_VERSION < 3
extern char refill_signature[3];
typedef PyObject EncodeBuffer;
#else
extern const char* refill_signature;
struct EncodeBuffer {
std::vector<char> buf;
size_t pos;
};
#endif
/**
* A cache of the spec_args for a set or list,
* so we don't have to keep calling PyTuple_GET_ITEM.
*/
struct SetListTypeArgs {
TType element_type;
PyObject* typeargs;
bool immutable;
};
/**
* A cache of the spec_args for a map,
* so we don't have to keep calling PyTuple_GET_ITEM.
*/
struct MapTypeArgs {
TType ktag;
TType vtag;
PyObject* ktypeargs;
PyObject* vtypeargs;
bool immutable;
};
/**
* A cache of the spec_args for a struct,
* so we don't have to keep calling PyTuple_GET_ITEM.
*/
struct StructTypeArgs {
PyObject* klass;
PyObject* spec;
bool immutable;
};
/**
* A cache of the item spec from a struct specification,
* so we don't have to keep calling PyTuple_GET_ITEM.
*/
struct StructItemSpec {
int tag;
TType type;
PyObject* attrname;
PyObject* typeargs;
PyObject* defval;
};
bool parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs);
bool parse_map_args(MapTypeArgs* dest, PyObject* typeargs);
bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs);
bool parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple);
}
}
}
#endif // THRIFT_PY_TYPES_H

View file

@ -0,0 +1,82 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.transport import TTransport
class TBase(object):
__slots__ = ()
def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key)) for key in self.__slots__]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
for attr in self.__slots__:
my_val = getattr(self, attr)
other_val = getattr(other, attr)
if my_val != other_val:
return False
return True
def __ne__(self, other):
return not (self == other)
def read(self, iprot):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
self.thrift_spec is not None):
iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))
else:
iprot.readStruct(self, self.thrift_spec)
def write(self, oprot):
if (oprot._fast_encode is not None and self.thrift_spec is not None):
oprot.trans.write(
oprot._fast_encode(self, (self.__class__, self.thrift_spec)))
else:
oprot.writeStruct(self, self.thrift_spec)
class TExceptionBase(TBase, Exception):
pass
class TFrozenBase(TBase):
def __setitem__(self, *args):
raise TypeError("Can't modify frozen struct")
def __delitem__(self, *args):
raise TypeError("Can't modify frozen struct")
def __hash__(self, *args):
return hash(self.__class__) ^ hash(self.__slots__)
@classmethod
def read(cls, iprot):
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
cls.thrift_spec is not None):
self = cls()
return iprot._fast_decode(None, iprot,
(self.__class__, self.thrift_spec))
else:
return iprot.readStruct(cls, cls.thrift_spec, True)

View file

@ -0,0 +1,301 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import TType, TProtocolBase, TProtocolException
from struct import pack, unpack
class TBinaryProtocol(TProtocolBase):
"""Binary implementation of the Thrift protocol driver."""
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
# positive, converting this into a long. If we hardcode the int value
# instead it'll stay in 32 bit-land.
# VERSION_MASK = 0xffff0000
VERSION_MASK = -65536
# VERSION_1 = 0x80010000
VERSION_1 = -2147418112
TYPE_MASK = 0x000000ff
def __init__(self, trans, strictRead=False, strictWrite=True, **kwargs):
TProtocolBase.__init__(self, trans)
self.strictRead = strictRead
self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def _check_string_length(self, length):
self._check_length(self.string_length_limit, length)
def _check_container_length(self, length):
self._check_length(self.container_length_limit, length)
def writeMessageBegin(self, name, type, seqid):
if self.strictWrite:
self.writeI32(TBinaryProtocol.VERSION_1 | type)
self.writeString(name)
self.writeI32(seqid)
else:
self.writeString(name)
self.writeByte(type)
self.writeI32(seqid)
def writeMessageEnd(self):
pass
def writeStructBegin(self, name):
pass
def writeStructEnd(self):
pass
def writeFieldBegin(self, name, type, id):
self.writeByte(type)
self.writeI16(id)
def writeFieldEnd(self):
pass
def writeFieldStop(self):
self.writeByte(TType.STOP)
def writeMapBegin(self, ktype, vtype, size):
self.writeByte(ktype)
self.writeByte(vtype)
self.writeI32(size)
def writeMapEnd(self):
pass
def writeListBegin(self, etype, size):
self.writeByte(etype)
self.writeI32(size)
def writeListEnd(self):
pass
def writeSetBegin(self, etype, size):
self.writeByte(etype)
self.writeI32(size)
def writeSetEnd(self):
pass
def writeBool(self, bool):
if bool:
self.writeByte(1)
else:
self.writeByte(0)
def writeByte(self, byte):
buff = pack("!b", byte)
self.trans.write(buff)
def writeI16(self, i16):
buff = pack("!h", i16)
self.trans.write(buff)
def writeI32(self, i32):
buff = pack("!i", i32)
self.trans.write(buff)
def writeI64(self, i64):
buff = pack("!q", i64)
self.trans.write(buff)
def writeDouble(self, dub):
buff = pack("!d", dub)
self.trans.write(buff)
def writeBinary(self, str):
self.writeI32(len(str))
self.trans.write(str)
def readMessageBegin(self):
sz = self.readI32()
if sz < 0:
version = sz & TBinaryProtocol.VERSION_MASK
if version != TBinaryProtocol.VERSION_1:
raise TProtocolException(
type=TProtocolException.BAD_VERSION,
message='Bad version in readMessageBegin: %d' % (sz))
type = sz & TBinaryProtocol.TYPE_MASK
name = self.readString()
seqid = self.readI32()
else:
if self.strictRead:
raise TProtocolException(type=TProtocolException.BAD_VERSION,
message='No protocol version header')
name = self.trans.readAll(sz)
type = self.readByte()
seqid = self.readI32()
return (name, type, seqid)
def readMessageEnd(self):
pass
def readStructBegin(self):
pass
def readStructEnd(self):
pass
def readFieldBegin(self):
type = self.readByte()
if type == TType.STOP:
return (None, type, 0)
id = self.readI16()
return (None, type, id)
def readFieldEnd(self):
pass
def readMapBegin(self):
ktype = self.readByte()
vtype = self.readByte()
size = self.readI32()
self._check_container_length(size)
return (ktype, vtype, size)
def readMapEnd(self):
pass
def readListBegin(self):
etype = self.readByte()
size = self.readI32()
self._check_container_length(size)
return (etype, size)
def readListEnd(self):
pass
def readSetBegin(self):
etype = self.readByte()
size = self.readI32()
self._check_container_length(size)
return (etype, size)
def readSetEnd(self):
pass
def readBool(self):
byte = self.readByte()
if byte == 0:
return False
return True
def readByte(self):
buff = self.trans.readAll(1)
val, = unpack('!b', buff)
return val
def readI16(self):
buff = self.trans.readAll(2)
val, = unpack('!h', buff)
return val
def readI32(self):
buff = self.trans.readAll(4)
val, = unpack('!i', buff)
return val
def readI64(self):
buff = self.trans.readAll(8)
val, = unpack('!q', buff)
return val
def readDouble(self):
buff = self.trans.readAll(8)
val, = unpack('!d', buff)
return val
def readBinary(self):
size = self.readI32()
self._check_string_length(size)
s = self.trans.readAll(size)
return s
class TBinaryProtocolFactory(object):
def __init__(self, strictRead=False, strictWrite=True, **kwargs):
self.strictRead = strictRead
self.strictWrite = strictWrite
self.string_length_limit = kwargs.get('string_length_limit', None)
self.container_length_limit = kwargs.get('container_length_limit', None)
def getProtocol(self, trans):
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit)
return prot
class TBinaryProtocolAccelerated(TBinaryProtocol):
"""C-Accelerated version of TBinaryProtocol.
This class does not override any of TBinaryProtocol's methods,
but the generated code recognizes it directly and will call into
our C module to do the encoding, bypassing this object entirely.
We inherit from TBinaryProtocol so that the normal TBinaryProtocol
encoding can happen if the fastbinary module doesn't work for some
reason. (TODO(dreiss): Make this happen sanely in more cases.)
To disable this behavior, pass fallback=False constructor argument.
In order to take advantage of the C module, just use
TBinaryProtocolAccelerated instead of TBinaryProtocol.
NOTE: This code was contributed by an external developer.
The internal Thrift team has reviewed and tested it,
but we cannot guarantee that it is production-ready.
Please feel free to report bugs and/or success stories
to the public mailing list.
"""
pass
def __init__(self, *args, **kwargs):
fallback = kwargs.pop('fallback', True)
super(TBinaryProtocolAccelerated, self).__init__(*args, **kwargs)
try:
from thrift.protocol import fastbinary
except ImportError:
if not fallback:
raise
else:
self._fast_decode = fastbinary.decode_binary
self._fast_encode = fastbinary.encode_binary
class TBinaryProtocolAcceleratedFactory(object):
def __init__(self,
string_length_limit=None,
container_length_limit=None,
fallback=True):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
self._fallback = fallback
def getProtocol(self, trans):
return TBinaryProtocolAccelerated(
trans,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit,
fallback=self._fallback)

View file

@ -0,0 +1,472 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import TType, TProtocolBase, TProtocolException, checkIntegerLimits
from struct import pack, unpack
from ..compat import binary_to_str, str_to_binary
__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']
CLEAR = 0
FIELD_WRITE = 1
VALUE_WRITE = 2
CONTAINER_WRITE = 3
BOOL_WRITE = 4
FIELD_READ = 5
CONTAINER_READ = 6
VALUE_READ = 7
BOOL_READ = 8
def make_helper(v_from, container):
def helper(func):
def nested(self, *args, **kwargs):
assert self.state in (v_from, container), (self.state, v_from, container)
return func(self, *args, **kwargs)
return nested
return helper
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
def makeZigZag(n, bits):
checkIntegerLimits(n, bits)
return (n << 1) ^ (n >> (bits - 1))
def fromZigZag(n):
return (n >> 1) ^ -(n & 1)
def writeVarint(trans, n):
out = bytearray()
while True:
if n & ~0x7f == 0:
out.append(n)
break
else:
out.append((n & 0xff) | 0x80)
n = n >> 7
trans.write(bytes(out))
def readVarint(trans):
result = 0
shift = 0
while True:
x = trans.readAll(1)
byte = ord(x)
result |= (byte & 0x7f) << shift
if byte >> 7 == 0:
return result
shift += 7
class CompactType(object):
STOP = 0x00
TRUE = 0x01
FALSE = 0x02
BYTE = 0x03
I16 = 0x04
I32 = 0x05
I64 = 0x06
DOUBLE = 0x07
BINARY = 0x08
LIST = 0x09
SET = 0x0A
MAP = 0x0B
STRUCT = 0x0C
CTYPES = {
TType.STOP: CompactType.STOP,
TType.BOOL: CompactType.TRUE, # used for collection
TType.BYTE: CompactType.BYTE,
TType.I16: CompactType.I16,
TType.I32: CompactType.I32,
TType.I64: CompactType.I64,
TType.DOUBLE: CompactType.DOUBLE,
TType.STRING: CompactType.BINARY,
TType.STRUCT: CompactType.STRUCT,
TType.LIST: CompactType.LIST,
TType.SET: CompactType.SET,
TType.MAP: CompactType.MAP,
}
TTYPES = {}
for k, v in CTYPES.items():
TTYPES[v] = k
TTYPES[CompactType.FALSE] = TType.BOOL
del k
del v
class TCompactProtocol(TProtocolBase):
"""Compact implementation of the Thrift protocol driver."""
PROTOCOL_ID = 0x82
VERSION = 1
VERSION_MASK = 0x1f
TYPE_MASK = 0xe0
TYPE_BITS = 0x07
TYPE_SHIFT_AMOUNT = 5
def __init__(self, trans,
string_length_limit=None,
container_length_limit=None):
TProtocolBase.__init__(self, trans)
self.state = CLEAR
self.__last_fid = 0
self.__bool_fid = None
self.__bool_value = None
self.__structs = []
self.__containers = []
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
def _check_string_length(self, length):
self._check_length(self.string_length_limit, length)
def _check_container_length(self, length):
self._check_length(self.container_length_limit, length)
def __writeVarint(self, n):
writeVarint(self.trans, n)
def writeMessageBegin(self, name, type, seqid):
assert self.state == CLEAR
self.__writeUByte(self.PROTOCOL_ID)
self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
self.__writeVarint(seqid)
self.__writeBinary(str_to_binary(name))
self.state = VALUE_WRITE
def writeMessageEnd(self):
assert self.state == VALUE_WRITE
self.state = CLEAR
def writeStructBegin(self, name):
assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
self.__structs.append((self.state, self.__last_fid))
self.state = FIELD_WRITE
self.__last_fid = 0
def writeStructEnd(self):
assert self.state == FIELD_WRITE
self.state, self.__last_fid = self.__structs.pop()
def writeFieldStop(self):
self.__writeByte(0)
def __writeFieldHeader(self, type, fid):
delta = fid - self.__last_fid
if 0 < delta <= 15:
self.__writeUByte(delta << 4 | type)
else:
self.__writeByte(type)
self.__writeI16(fid)
self.__last_fid = fid
def writeFieldBegin(self, name, type, fid):
assert self.state == FIELD_WRITE, self.state
if type == TType.BOOL:
self.state = BOOL_WRITE
self.__bool_fid = fid
else:
self.state = VALUE_WRITE
self.__writeFieldHeader(CTYPES[type], fid)
def writeFieldEnd(self):
assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
self.state = FIELD_WRITE
def __writeUByte(self, byte):
self.trans.write(pack('!B', byte))
def __writeByte(self, byte):
self.trans.write(pack('!b', byte))
def __writeI16(self, i16):
self.__writeVarint(makeZigZag(i16, 16))
def __writeSize(self, i32):
self.__writeVarint(i32)
def writeCollectionBegin(self, etype, size):
assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
if size <= 14:
self.__writeUByte(size << 4 | CTYPES[etype])
else:
self.__writeUByte(0xf0 | CTYPES[etype])
self.__writeSize(size)
self.__containers.append(self.state)
self.state = CONTAINER_WRITE
writeSetBegin = writeCollectionBegin
writeListBegin = writeCollectionBegin
def writeMapBegin(self, ktype, vtype, size):
assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
if size == 0:
self.__writeByte(0)
else:
self.__writeSize(size)
self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
self.__containers.append(self.state)
self.state = CONTAINER_WRITE
def writeCollectionEnd(self):
assert self.state == CONTAINER_WRITE, self.state
self.state = self.__containers.pop()
writeMapEnd = writeCollectionEnd
writeSetEnd = writeCollectionEnd
writeListEnd = writeCollectionEnd
def writeBool(self, bool):
if self.state == BOOL_WRITE:
if bool:
ctype = CompactType.TRUE
else:
ctype = CompactType.FALSE
self.__writeFieldHeader(ctype, self.__bool_fid)
elif self.state == CONTAINER_WRITE:
if bool:
self.__writeByte(CompactType.TRUE)
else:
self.__writeByte(CompactType.FALSE)
else:
raise AssertionError("Invalid state in compact protocol")
writeByte = writer(__writeByte)
writeI16 = writer(__writeI16)
@writer
def writeI32(self, i32):
self.__writeVarint(makeZigZag(i32, 32))
@writer
def writeI64(self, i64):
self.__writeVarint(makeZigZag(i64, 64))
@writer
def writeDouble(self, dub):
self.trans.write(pack('<d', dub))
def __writeBinary(self, s):
self.__writeSize(len(s))
self.trans.write(s)
writeBinary = writer(__writeBinary)
def readFieldBegin(self):
assert self.state == FIELD_READ, self.state
type = self.__readUByte()
if type & 0x0f == TType.STOP:
return (None, 0, 0)
delta = type >> 4
if delta == 0:
fid = self.__readI16()
else:
fid = self.__last_fid + delta
self.__last_fid = fid
type = type & 0x0f
if type == CompactType.TRUE:
self.state = BOOL_READ
self.__bool_value = True
elif type == CompactType.FALSE:
self.state = BOOL_READ
self.__bool_value = False
else:
self.state = VALUE_READ
return (None, self.__getTType(type), fid)
def readFieldEnd(self):
assert self.state in (VALUE_READ, BOOL_READ), self.state
self.state = FIELD_READ
def __readUByte(self):
result, = unpack('!B', self.trans.readAll(1))
return result
def __readByte(self):
result, = unpack('!b', self.trans.readAll(1))
return result
def __readVarint(self):
return readVarint(self.trans)
def __readZigZag(self):
return fromZigZag(self.__readVarint())
def __readSize(self):
result = self.__readVarint()
if result < 0:
raise TProtocolException("Length < 0")
return result
def readMessageBegin(self):
assert self.state == CLEAR
proto_id = self.__readUByte()
if proto_id != self.PROTOCOL_ID:
raise TProtocolException(TProtocolException.BAD_VERSION,
'Bad protocol id in the message: %d' % proto_id)
ver_type = self.__readUByte()
type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
version = ver_type & self.VERSION_MASK
if version != self.VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION,
'Bad version: %d (expect %d)' % (version, self.VERSION))
seqid = self.__readVarint()
name = binary_to_str(self.__readBinary())
return (name, type, seqid)
def readMessageEnd(self):
assert self.state == CLEAR
assert len(self.__structs) == 0
def readStructBegin(self):
assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
self.__structs.append((self.state, self.__last_fid))
self.state = FIELD_READ
self.__last_fid = 0
def readStructEnd(self):
assert self.state == FIELD_READ
self.state, self.__last_fid = self.__structs.pop()
def readCollectionBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size_type = self.__readUByte()
size = size_type >> 4
type = self.__getTType(size_type)
if size == 15:
size = self.__readSize()
self._check_container_length(size)
self.__containers.append(self.state)
self.state = CONTAINER_READ
return type, size
readSetBegin = readCollectionBegin
readListBegin = readCollectionBegin
def readMapBegin(self):
assert self.state in (VALUE_READ, CONTAINER_READ), self.state
size = self.__readSize()
self._check_container_length(size)
types = 0
if size > 0:
types = self.__readUByte()
vtype = self.__getTType(types)
ktype = self.__getTType(types >> 4)
self.__containers.append(self.state)
self.state = CONTAINER_READ
return (ktype, vtype, size)
def readCollectionEnd(self):
assert self.state == CONTAINER_READ, self.state
self.state = self.__containers.pop()
readSetEnd = readCollectionEnd
readListEnd = readCollectionEnd
readMapEnd = readCollectionEnd
def readBool(self):
if self.state == BOOL_READ:
return self.__bool_value == CompactType.TRUE
elif self.state == CONTAINER_READ:
return self.__readByte() == CompactType.TRUE
else:
raise AssertionError("Invalid state in compact protocol: %d" %
self.state)
readByte = reader(__readByte)
__readI16 = __readZigZag
readI16 = reader(__readZigZag)
readI32 = reader(__readZigZag)
readI64 = reader(__readZigZag)
@reader
def readDouble(self):
buff = self.trans.readAll(8)
val, = unpack('<d', buff)
return val
def __readBinary(self):
size = self.__readSize()
self._check_string_length(size)
return self.trans.readAll(size)
readBinary = reader(__readBinary)
def __getTType(self, byte):
return TTYPES[byte & 0x0f]
class TCompactProtocolFactory(object):
def __init__(self,
string_length_limit=None,
container_length_limit=None):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
def getProtocol(self, trans):
return TCompactProtocol(trans,
self.string_length_limit,
self.container_length_limit)
class TCompactProtocolAccelerated(TCompactProtocol):
"""C-Accelerated version of TCompactProtocol.
This class does not override any of TCompactProtocol's methods,
but the generated code recognizes it directly and will call into
our C module to do the encoding, bypassing this object entirely.
We inherit from TCompactProtocol so that the normal TCompactProtocol
encoding can happen if the fastbinary module doesn't work for some
reason.
To disable this behavior, pass fallback=False constructor argument.
In order to take advantage of the C module, just use
TCompactProtocolAccelerated instead of TCompactProtocol.
"""
pass
def __init__(self, *args, **kwargs):
fallback = kwargs.pop('fallback', True)
super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
try:
from thrift.protocol import fastbinary
except ImportError:
if not fallback:
raise
else:
self._fast_decode = fastbinary.decode_compact
self._fast_encode = fastbinary.encode_compact
class TCompactProtocolAcceleratedFactory(object):
def __init__(self,
string_length_limit=None,
container_length_limit=None,
fallback=True):
self.string_length_limit = string_length_limit
self.container_length_limit = container_length_limit
self._fallback = fallback
def getProtocol(self, trans):
return TCompactProtocolAccelerated(
trans,
string_length_limit=self.string_length_limit,
container_length_limit=self.container_length_limit,
fallback=self._fallback)

View file

@ -0,0 +1,677 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from .TProtocol import (TType, TProtocolBase, TProtocolException,
checkIntegerLimits)
import base64
import math
import sys
from ..compat import str_to_binary
__all__ = ['TJSONProtocol',
'TJSONProtocolFactory',
'TSimpleJSONProtocol',
'TSimpleJSONProtocolFactory']
VERSION = 1
COMMA = b','
COLON = b':'
LBRACE = b'{'
RBRACE = b'}'
LBRACKET = b'['
RBRACKET = b']'
QUOTE = b'"'
BACKSLASH = b'\\'
ZERO = b'0'
ESCSEQ0 = ord('\\')
ESCSEQ1 = ord('u')
ESCAPE_CHAR_VALS = {
'"': '\\"',
'\\': '\\\\',
'\b': '\\b',
'\f': '\\f',
'\n': '\\n',
'\r': '\\r',
'\t': '\\t',
# '/': '\\/',
}
ESCAPE_CHARS = {
b'"': '"',
b'\\': '\\',
b'b': '\b',
b'f': '\f',
b'n': '\n',
b'r': '\r',
b't': '\t',
b'/': '/',
}
NUMERIC_CHAR = b'+-.0123456789Ee'
CTYPES = {
TType.BOOL: 'tf',
TType.BYTE: 'i8',
TType.I16: 'i16',
TType.I32: 'i32',
TType.I64: 'i64',
TType.DOUBLE: 'dbl',
TType.STRING: 'str',
TType.STRUCT: 'rec',
TType.LIST: 'lst',
TType.SET: 'set',
TType.MAP: 'map',
}
JTYPES = {}
for key in CTYPES.keys():
JTYPES[CTYPES[key]] = key
class JSONBaseContext(object):
def __init__(self, protocol):
self.protocol = protocol
self.first = True
def doIO(self, function):
pass
def write(self):
pass
def read(self):
pass
def escapeNum(self):
return False
def __str__(self):
return self.__class__.__name__
class JSONListContext(JSONBaseContext):
def doIO(self, function):
if self.first is True:
self.first = False
else:
function(COMMA)
def write(self):
self.doIO(self.protocol.trans.write)
def read(self):
self.doIO(self.protocol.readJSONSyntaxChar)
class JSONPairContext(JSONBaseContext):
def __init__(self, protocol):
super(JSONPairContext, self).__init__(protocol)
self.colon = True
def doIO(self, function):
if self.first:
self.first = False
self.colon = True
else:
function(COLON if self.colon else COMMA)
self.colon = not self.colon
def write(self):
self.doIO(self.protocol.trans.write)
def read(self):
self.doIO(self.protocol.readJSONSyntaxChar)
def escapeNum(self):
return self.colon
def __str__(self):
return '%s, colon=%s' % (self.__class__.__name__, self.colon)
class LookaheadReader():
hasData = False
data = ''
def __init__(self, protocol):
self.protocol = protocol
def read(self):
if self.hasData is True:
self.hasData = False
else:
self.data = self.protocol.trans.read(1)
return self.data
def peek(self):
if self.hasData is False:
self.data = self.protocol.trans.read(1)
self.hasData = True
return self.data
class TJSONProtocolBase(TProtocolBase):
def __init__(self, trans):
TProtocolBase.__init__(self, trans)
self.resetWriteContext()
self.resetReadContext()
# We don't have length limit implementation for JSON protocols
@property
def string_length_limit(senf):
return None
@property
def container_length_limit(senf):
return None
def resetWriteContext(self):
self.context = JSONBaseContext(self)
self.contextStack = [self.context]
def resetReadContext(self):
self.resetWriteContext()
self.reader = LookaheadReader(self)
def pushContext(self, ctx):
self.contextStack.append(ctx)
self.context = ctx
def popContext(self):
self.contextStack.pop()
if self.contextStack:
self.context = self.contextStack[-1]
else:
self.context = JSONBaseContext(self)
def writeJSONString(self, string):
self.context.write()
json_str = ['"']
for s in string:
escaped = ESCAPE_CHAR_VALS.get(s, s)
json_str.append(escaped)
json_str.append('"')
self.trans.write(str_to_binary(''.join(json_str)))
def writeJSONNumber(self, number, formatter='{0}'):
self.context.write()
jsNumber = str(formatter.format(number)).encode('ascii')
if self.context.escapeNum():
self.trans.write(QUOTE)
self.trans.write(jsNumber)
self.trans.write(QUOTE)
else:
self.trans.write(jsNumber)
def writeJSONBase64(self, binary):
self.context.write()
self.trans.write(QUOTE)
self.trans.write(base64.b64encode(binary))
self.trans.write(QUOTE)
def writeJSONObjectStart(self):
self.context.write()
self.trans.write(LBRACE)
self.pushContext(JSONPairContext(self))
def writeJSONObjectEnd(self):
self.popContext()
self.trans.write(RBRACE)
def writeJSONArrayStart(self):
self.context.write()
self.trans.write(LBRACKET)
self.pushContext(JSONListContext(self))
def writeJSONArrayEnd(self):
self.popContext()
self.trans.write(RBRACKET)
def readJSONSyntaxChar(self, character):
current = self.reader.read()
if character != current:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Unexpected character: %s" % current)
def _isHighSurrogate(self, codeunit):
return codeunit >= 0xd800 and codeunit <= 0xdbff
def _isLowSurrogate(self, codeunit):
return codeunit >= 0xdc00 and codeunit <= 0xdfff
def _toChar(self, high, low=None):
if not low:
if sys.version_info[0] == 2:
return ("\\u%04x" % high).decode('unicode-escape') \
.encode('utf-8')
else:
return chr(high)
else:
codepoint = (1 << 16) + ((high & 0x3ff) << 10)
codepoint += low & 0x3ff
if sys.version_info[0] == 2:
s = "\\U%08x" % codepoint
return s.decode('unicode-escape').encode('utf-8')
else:
return chr(codepoint)
def readJSONString(self, skipContext):
highSurrogate = None
string = []
if skipContext is False:
self.context.read()
self.readJSONSyntaxChar(QUOTE)
while True:
character = self.reader.read()
if character == QUOTE:
break
if ord(character) == ESCSEQ0:
character = self.reader.read()
if ord(character) == ESCSEQ1:
character = self.trans.read(4).decode('ascii')
codeunit = int(character, 16)
if self._isHighSurrogate(codeunit):
if highSurrogate:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Expected low surrogate char")
highSurrogate = codeunit
continue
elif self._isLowSurrogate(codeunit):
if not highSurrogate:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Expected high surrogate char")
character = self._toChar(highSurrogate, codeunit)
highSurrogate = None
else:
character = self._toChar(codeunit)
else:
if character not in ESCAPE_CHARS:
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Expected control char")
character = ESCAPE_CHARS[character]
elif character in ESCAPE_CHAR_VALS:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Unescaped control char")
elif sys.version_info[0] > 2:
utf8_bytes = bytearray([ord(character)])
while ord(self.reader.peek()) >= 0x80:
utf8_bytes.append(ord(self.reader.read()))
character = utf8_bytes.decode('utf8')
string.append(character)
if highSurrogate:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Expected low surrogate char")
return ''.join(string)
def isJSONNumeric(self, character):
return (True if NUMERIC_CHAR.find(character) != - 1 else False)
def readJSONQuotes(self):
if (self.context.escapeNum()):
self.readJSONSyntaxChar(QUOTE)
def readJSONNumericChars(self):
numeric = []
while True:
character = self.reader.peek()
if self.isJSONNumeric(character) is False:
break
numeric.append(self.reader.read())
return b''.join(numeric).decode('ascii')
def readJSONInteger(self):
self.context.read()
self.readJSONQuotes()
numeric = self.readJSONNumericChars()
self.readJSONQuotes()
try:
return int(numeric)
except ValueError:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Bad data encounted in numeric data")
def readJSONDouble(self):
self.context.read()
if self.reader.peek() == QUOTE:
string = self.readJSONString(True)
try:
double = float(string)
if (self.context.escapeNum is False and
not math.isinf(double) and
not math.isnan(double)):
raise TProtocolException(
TProtocolException.INVALID_DATA,
"Numeric data unexpectedly quoted")
return double
except ValueError:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Bad data encounted in numeric data")
else:
if self.context.escapeNum() is True:
self.readJSONSyntaxChar(QUOTE)
try:
return float(self.readJSONNumericChars())
except ValueError:
raise TProtocolException(TProtocolException.INVALID_DATA,
"Bad data encounted in numeric data")
def readJSONBase64(self):
string = self.readJSONString(False)
size = len(string)
m = size % 4
# Force padding since b64encode method does not allow it
if m != 0:
for i in range(4 - m):
string += '='
return base64.b64decode(string)
def readJSONObjectStart(self):
self.context.read()
self.readJSONSyntaxChar(LBRACE)
self.pushContext(JSONPairContext(self))
def readJSONObjectEnd(self):
self.readJSONSyntaxChar(RBRACE)
self.popContext()
def readJSONArrayStart(self):
self.context.read()
self.readJSONSyntaxChar(LBRACKET)
self.pushContext(JSONListContext(self))
def readJSONArrayEnd(self):
self.readJSONSyntaxChar(RBRACKET)
self.popContext()
class TJSONProtocol(TJSONProtocolBase):
def readMessageBegin(self):
self.resetReadContext()
self.readJSONArrayStart()
if self.readJSONInteger() != VERSION:
raise TProtocolException(TProtocolException.BAD_VERSION,
"Message contained bad version.")
name = self.readJSONString(False)
typen = self.readJSONInteger()
seqid = self.readJSONInteger()
return (name, typen, seqid)
def readMessageEnd(self):
self.readJSONArrayEnd()
def readStructBegin(self):
self.readJSONObjectStart()
def readStructEnd(self):
self.readJSONObjectEnd()
def readFieldBegin(self):
character = self.reader.peek()
ttype = 0
id = 0
if character == RBRACE:
ttype = TType.STOP
else:
id = self.readJSONInteger()
self.readJSONObjectStart()
ttype = JTYPES[self.readJSONString(False)]
return (None, ttype, id)
def readFieldEnd(self):
self.readJSONObjectEnd()
def readMapBegin(self):
self.readJSONArrayStart()
keyType = JTYPES[self.readJSONString(False)]
valueType = JTYPES[self.readJSONString(False)]
size = self.readJSONInteger()
self.readJSONObjectStart()
return (keyType, valueType, size)
def readMapEnd(self):
self.readJSONObjectEnd()
self.readJSONArrayEnd()
def readCollectionBegin(self):
self.readJSONArrayStart()
elemType = JTYPES[self.readJSONString(False)]
size = self.readJSONInteger()
return (elemType, size)
readListBegin = readCollectionBegin
readSetBegin = readCollectionBegin
def readCollectionEnd(self):
self.readJSONArrayEnd()
readSetEnd = readCollectionEnd
readListEnd = readCollectionEnd
def readBool(self):
return (False if self.readJSONInteger() == 0 else True)
def readNumber(self):
return self.readJSONInteger()
readByte = readNumber
readI16 = readNumber
readI32 = readNumber
readI64 = readNumber
def readDouble(self):
return self.readJSONDouble()
def readString(self):
return self.readJSONString(False)
def readBinary(self):
return self.readJSONBase64()
def writeMessageBegin(self, name, request_type, seqid):
self.resetWriteContext()
self.writeJSONArrayStart()
self.writeJSONNumber(VERSION)
self.writeJSONString(name)
self.writeJSONNumber(request_type)
self.writeJSONNumber(seqid)
def writeMessageEnd(self):
self.writeJSONArrayEnd()
def writeStructBegin(self, name):
self.writeJSONObjectStart()
def writeStructEnd(self):
self.writeJSONObjectEnd()
def writeFieldBegin(self, name, ttype, id):
self.writeJSONNumber(id)
self.writeJSONObjectStart()
self.writeJSONString(CTYPES[ttype])
def writeFieldEnd(self):
self.writeJSONObjectEnd()
def writeFieldStop(self):
pass
def writeMapBegin(self, ktype, vtype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[ktype])
self.writeJSONString(CTYPES[vtype])
self.writeJSONNumber(size)
self.writeJSONObjectStart()
def writeMapEnd(self):
self.writeJSONObjectEnd()
self.writeJSONArrayEnd()
def writeListBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
def writeListEnd(self):
self.writeJSONArrayEnd()
def writeSetBegin(self, etype, size):
self.writeJSONArrayStart()
self.writeJSONString(CTYPES[etype])
self.writeJSONNumber(size)
def writeSetEnd(self):
self.writeJSONArrayEnd()
def writeBool(self, boolean):
self.writeJSONNumber(1 if boolean is True else 0)
def writeByte(self, byte):
checkIntegerLimits(byte, 8)
self.writeJSONNumber(byte)
def writeI16(self, i16):
checkIntegerLimits(i16, 16)
self.writeJSONNumber(i16)
def writeI32(self, i32):
checkIntegerLimits(i32, 32)
self.writeJSONNumber(i32)
def writeI64(self, i64):
checkIntegerLimits(i64, 64)
self.writeJSONNumber(i64)
def writeDouble(self, dbl):
# 17 significant digits should be just enough for any double precision
# value.
self.writeJSONNumber(dbl, '{0:.17g}')
def writeString(self, string):
self.writeJSONString(string)
def writeBinary(self, binary):
self.writeJSONBase64(binary)
class TJSONProtocolFactory(object):
def getProtocol(self, trans):
return TJSONProtocol(trans)
@property
def string_length_limit(senf):
return None
@property
def container_length_limit(senf):
return None
class TSimpleJSONProtocol(TJSONProtocolBase):
"""Simple, readable, write-only JSON protocol.
Useful for interacting with scripting languages.
"""
def readMessageBegin(self):
raise NotImplementedError()
def readMessageEnd(self):
raise NotImplementedError()
def readStructBegin(self):
raise NotImplementedError()
def readStructEnd(self):
raise NotImplementedError()
def writeMessageBegin(self, name, request_type, seqid):
self.resetWriteContext()
def writeMessageEnd(self):
pass
def writeStructBegin(self, name):
self.writeJSONObjectStart()
def writeStructEnd(self):
self.writeJSONObjectEnd()
def writeFieldBegin(self, name, ttype, fid):
self.writeJSONString(name)
def writeFieldEnd(self):
pass
def writeMapBegin(self, ktype, vtype, size):
self.writeJSONObjectStart()
def writeMapEnd(self):
self.writeJSONObjectEnd()
def _writeCollectionBegin(self, etype, size):
self.writeJSONArrayStart()
def _writeCollectionEnd(self):
self.writeJSONArrayEnd()
writeListBegin = _writeCollectionBegin
writeListEnd = _writeCollectionEnd
writeSetBegin = _writeCollectionBegin
writeSetEnd = _writeCollectionEnd
def writeByte(self, byte):
checkIntegerLimits(byte, 8)
self.writeJSONNumber(byte)
def writeI16(self, i16):
checkIntegerLimits(i16, 16)
self.writeJSONNumber(i16)
def writeI32(self, i32):
checkIntegerLimits(i32, 32)
self.writeJSONNumber(i32)
def writeI64(self, i64):
checkIntegerLimits(i64, 64)
self.writeJSONNumber(i64)
def writeBool(self, boolean):
self.writeJSONNumber(1 if boolean is True else 0)
def writeDouble(self, dbl):
self.writeJSONNumber(dbl)
def writeString(self, string):
self.writeJSONString(string)
def writeBinary(self, binary):
self.writeJSONBase64(binary)
class TSimpleJSONProtocolFactory(object):
def getProtocol(self, trans):
return TSimpleJSONProtocol(trans)

View file

@ -0,0 +1,40 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TMessageType
from thrift.protocol import TProtocolDecorator
SEPARATOR = ":"
class TMultiplexedProtocol(TProtocolDecorator.TProtocolDecorator):
def __init__(self, protocol, serviceName):
TProtocolDecorator.TProtocolDecorator.__init__(self, protocol)
self.serviceName = serviceName
def writeMessageBegin(self, name, type, seqid):
if (type == TMessageType.CALL or
type == TMessageType.ONEWAY):
self.protocol.writeMessageBegin(
self.serviceName + SEPARATOR + name,
type,
seqid
)
else:
self.protocol.writeMessageBegin(name, type, seqid)

View file

@ -0,0 +1,419 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from thrift.Thrift import TException, TType, TFrozenDict
from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary
import six
import sys
from itertools import islice
from six.moves import zip
class TProtocolException(TException):
"""Custom Protocol Exception class"""
UNKNOWN = 0
INVALID_DATA = 1
NEGATIVE_SIZE = 2
SIZE_LIMIT = 3
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
self.type = type
class TProtocolBase(object):
"""Base class for Thrift protocol driver."""
def __init__(self, trans):
self.trans = trans
self._fast_decode = None
self._fast_encode = None
@staticmethod
def _check_length(limit, length):
if length < 0:
raise TTransportException(TTransportException.NEGATIVE_SIZE,
'Negative length: %d' % length)
if limit is not None and length > limit:
raise TTransportException(TTransportException.SIZE_LIMIT,
'Length exceeded max allowed: %d' % limit)
def writeMessageBegin(self, name, ttype, seqid):
pass
def writeMessageEnd(self):
pass
def writeStructBegin(self, name):
pass
def writeStructEnd(self):
pass
def writeFieldBegin(self, name, ttype, fid):
pass
def writeFieldEnd(self):
pass
def writeFieldStop(self):
pass
def writeMapBegin(self, ktype, vtype, size):
pass
def writeMapEnd(self):
pass
def writeListBegin(self, etype, size):
pass
def writeListEnd(self):
pass
def writeSetBegin(self, etype, size):
pass
def writeSetEnd(self):
pass
def writeBool(self, bool_val):
pass
def writeByte(self, byte):
pass
def writeI16(self, i16):
pass
def writeI32(self, i32):
pass
def writeI64(self, i64):
pass
def writeDouble(self, dub):
pass
def writeString(self, str_val):
self.writeBinary(str_to_binary(str_val))
def writeBinary(self, str_val):
pass
def writeUtf8(self, str_val):
self.writeString(str_val.encode('utf8'))
def readMessageBegin(self):
pass
def readMessageEnd(self):
pass
def readStructBegin(self):
pass
def readStructEnd(self):
pass
def readFieldBegin(self):
pass
def readFieldEnd(self):
pass
def readMapBegin(self):
pass
def readMapEnd(self):
pass
def readListBegin(self):
pass
def readListEnd(self):
pass
def readSetBegin(self):
pass
def readSetEnd(self):
pass
def readBool(self):
pass
def readByte(self):
pass
def readI16(self):
pass
def readI32(self):
pass
def readI64(self):
pass
def readDouble(self):
pass
def readString(self):
return binary_to_str(self.readBinary())
def readBinary(self):
pass
def readUtf8(self):
return self.readString().decode('utf8')
def skip(self, ttype):
if ttype == TType.STOP:
return
elif ttype == TType.BOOL:
self.readBool()
elif ttype == TType.BYTE:
self.readByte()
elif ttype == TType.I16:
self.readI16()
elif ttype == TType.I32:
self.readI32()
elif ttype == TType.I64:
self.readI64()
elif ttype == TType.DOUBLE:
self.readDouble()
elif ttype == TType.STRING:
self.readString()
elif ttype == TType.STRUCT:
name = self.readStructBegin()
while True:
(name, ttype, id) = self.readFieldBegin()
if ttype == TType.STOP:
break
self.skip(ttype)
self.readFieldEnd()
self.readStructEnd()
elif ttype == TType.MAP:
(ktype, vtype, size) = self.readMapBegin()
for i in range(size):
self.skip(ktype)
self.skip(vtype)
self.readMapEnd()
elif ttype == TType.SET:
(etype, size) = self.readSetBegin()
for i in range(size):
self.skip(etype)
self.readSetEnd()
elif ttype == TType.LIST:
(etype, size) = self.readListBegin()
for i in range(size):
self.skip(etype)
self.readListEnd()
# tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
_TTYPE_HANDLERS = (
(None, None, False), # 0 TType.STOP
(None, None, False), # 1 TType.VOID # TODO: handle void?
('readBool', 'writeBool', False), # 2 TType.BOOL
('readByte', 'writeByte', False), # 3 TType.BYTE and I08
('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
(None, None, False), # 5 undefined
('readI16', 'writeI16', False), # 6 TType.I16
(None, None, False), # 7 undefined
('readI32', 'writeI32', False), # 8 TType.I32
(None, None, False), # 9 undefined
('readI64', 'writeI64', False), # 10 TType.I64
('readString', 'writeString', False), # 11 TType.STRING and UTF7
('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
('readContainerList', 'writeContainerList', True), # 15 TType.LIST
(None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
(None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
)
def _ttype_handlers(self, ttype, spec):
if spec == 'BINARY':
if ttype != TType.STRING:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid binary field type %d' % ttype)
return ('readBinary', 'writeBinary', False)
if sys.version_info[0] == 2 and spec == 'UTF8':
if ttype != TType.STRING:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid string field type %d' % ttype)
return ('readUtf8', 'writeUtf8', False)
return self._TTYPE_HANDLERS[ttype] if ttype < len(self._TTYPE_HANDLERS) else (None, None, False)
def _read_by_ttype(self, ttype, spec, espec):
reader_name, _, is_container = self._ttype_handlers(ttype, spec)
if reader_name is None:
raise TProtocolException(type=TProtocolException.INVALID_DATA,
message='Invalid type %d' % (ttype))
reader_func = getattr(self, reader_name)
read = (lambda: reader_func(espec)) if is_container else reader_func
while True:
yield read()
def readFieldByTType(self, ttype, spec):
return next(self._read_by_ttype(ttype, spec, spec))
def readContainerList(self, spec):
ttype, tspec, is_immutable = spec
(list_type, list_len) = self.readListBegin()
# TODO: compare types we just decoded with thrift_spec
elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
results = (tuple if is_immutable else list)(elems)
self.readListEnd()
return results
def readContainerSet(self, spec):
ttype, tspec, is_immutable = spec
(set_type, set_len) = self.readSetBegin()
# TODO: compare types we just decoded with thrift_spec
elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
results = (frozenset if is_immutable else set)(elems)
self.readSetEnd()
return results
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
obj = obj_class()
obj.read(self)
return obj
def readContainerMap(self, spec):
ktype, kspec, vtype, vspec, is_immutable = spec
(map_ktype, map_vtype, map_len) = self.readMapBegin()
# TODO: compare types we just decoded with thrift_spec and
# abort/skip if types disagree
keys = self._read_by_ttype(ktype, spec, kspec)
vals = self._read_by_ttype(vtype, spec, vspec)
keyvals = islice(zip(keys, vals), map_len)
results = (TFrozenDict if is_immutable else dict)(keyvals)
self.readMapEnd()
return results
def readStruct(self, obj, thrift_spec, is_immutable=False):
if is_immutable:
fields = {}
self.readStructBegin()
while True:
(fname, ftype, fid) = self.readFieldBegin()
if ftype == TType.STOP:
break
try:
field = thrift_spec[fid]
except IndexError:
self.skip(ftype)
else:
if field is not None and ftype == field[1]:
fname = field[2]
fspec = field[3]
val = self.readFieldByTType(ftype, fspec)
if is_immutable:
fields[fname] = val
else:
setattr(obj, fname, val)
else:
self.skip(ftype)
self.readFieldEnd()
self.readStructEnd()
if is_immutable:
return obj(**fields)
def writeContainerStruct(self, val, spec):
val.write(self)
def writeContainerList(self, val, spec):
ttype, tspec, _ = spec
self.writeListBegin(ttype, len(val))
for _ in self._write_by_ttype(ttype, val, spec, tspec):
pass
self.writeListEnd()
def writeContainerSet(self, val, spec):
ttype, tspec, _ = spec
self.writeSetBegin(ttype, len(val))
for _ in self._write_by_ttype(ttype, val, spec, tspec):
pass
self.writeSetEnd()
def writeContainerMap(self, val, spec):
ktype, kspec, vtype, vspec, _ = spec
self.writeMapBegin(ktype, vtype, len(val))
for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
pass
self.writeMapEnd()
def writeStruct(self, obj, thrift_spec):
self.writeStructBegin(obj.__class__.__name__)
for field in thrift_spec:
if field is None:
continue
fname = field[2]
val = getattr(obj, fname)
if val is None:
# skip writing out unset fields
continue
fid = field[0]
ftype = field[1]
fspec = field[3]
self.writeFieldBegin(fname, ftype, fid)
self.writeFieldByTType(ftype, val, fspec)
self.writeFieldEnd()
self.writeFieldStop()
self.writeStructEnd()
def _write_by_ttype(self, ttype, vals, spec, espec):
_, writer_name, is_container = self._ttype_handlers(ttype, spec)
writer_func = getattr(self, writer_name)
write = (lambda v: writer_func(v, espec)) if is_container else writer_func
for v in vals:
yield write(v)
def writeFieldByTType(self, ttype, val, spec):
next(self._write_by_ttype(ttype, [val], spec, spec))
def checkIntegerLimits(i, bits):
if bits == 8 and (i < -128 or i > 127):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i8 requires -128 <= number <= 127")
elif bits == 16 and (i < -32768 or i > 32767):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i16 requires -32768 <= number <= 32767")
elif bits == 32 and (i < -2147483648 or i > 2147483647):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i32 requires -2147483648 <= number <= 2147483647")
elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
raise TProtocolException(TProtocolException.INVALID_DATA,
"i64 requires -9223372036854775808 <= number <= 9223372036854775807")
class TProtocolFactory(object):
def getProtocol(self, trans):
pass

View file

@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import types
from thrift.protocol.TProtocol import TProtocolBase
class TProtocolDecorator():
def __init__(self, protocol):
TProtocolBase(protocol)
self.protocol = protocol
def __getattr__(self, name):
if hasattr(self.protocol, name):
member = getattr(self.protocol, name)
if type(member) in [
types.MethodType,
types.FunctionType,
types.LambdaType,
types.BuiltinFunctionType,
types.BuiltinMethodType,
]:
return lambda *args, **kwargs: self._wrap(member, args, kwargs)
else:
return member
raise AttributeError(name)
def _wrap(self, func, args, kwargs):
if isinstance(func, types.MethodType):
result = func(*args, **kwargs)
else:
result = func(self.protocol, *args, **kwargs)
return result

View file

@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol',
'TJSONProtocol', 'TProtocol']

View file

@ -0,0 +1,87 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from six.moves import BaseHTTPServer
from thrift.server import TServer
from thrift.transport import TTransport
class ResponseException(Exception):
"""Allows handlers to override the HTTP response
Normally, THttpServer always sends a 200 response. If a handler wants
to override this behavior (e.g., to simulate a misconfigured or
overloaded web server during testing), it can raise a ResponseException.
The function passed to the constructor will be called with the
RequestHandler as its only argument.
"""
def __init__(self, handler):
self.handler = handler
class THttpServer(TServer.TServer):
"""A simple HTTP-based Thrift server
This class is not very performant, but it is useful (for example) for
acting as a mock version of an Apache-based PHP Thrift endpoint.
"""
def __init__(self,
processor,
server_address,
inputProtocolFactory,
outputProtocolFactory=None,
server_class=BaseHTTPServer.HTTPServer):
"""Set up protocol factories and HTTP server.
See BaseHTTPServer for server_address.
See TServer for protocol factories.
"""
if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory
TServer.TServer.__init__(self, processor, None, None, None,
inputProtocolFactory, outputProtocolFactory)
thttpserver = self
class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler):
def do_POST(self):
# Don't care about the request path.
itrans = TTransport.TFileObjectTransport(self.rfile)
otrans = TTransport.TFileObjectTransport(self.wfile)
itrans = TTransport.TBufferedTransport(
itrans, int(self.headers['Content-Length']))
otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
try:
thttpserver.processor.process(iprot, oprot)
except ResponseException as exn:
exn.handler(self)
else:
self.send_response(200)
self.send_header("content-type", "application/x-thrift")
self.end_headers()
self.wfile.write(otrans.getvalue())
self.httpd = server_class(server_address, RequestHander)
def serve(self):
self.httpd.serve_forever()

View file

@ -0,0 +1,350 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""Implementation of non-blocking server.
The main idea of the server is to receive and send requests
only from the main thread.
The thread poool should be sized for concurrent tasks, not
maximum connections
"""
import logging
import select
import socket
import struct
import threading
from six.moves import queue
from thrift.transport import TTransport
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
__all__ = ['TNonblockingServer']
logger = logging.getLogger(__name__)
class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection."""
def __init__(self, queue):
threading.Thread.__init__(self)
self.queue = queue
def run(self):
"""Process queries from task queue, stop if processor is None."""
while True:
try:
processor, iprot, oprot, otrans, callback = self.queue.get()
if processor is None:
break
processor.process(iprot, oprot)
callback(True, otrans.getvalue())
except Exception:
logger.exception("Exception while processing request")
callback(False, b'')
WAIT_LEN = 0
WAIT_MESSAGE = 1
WAIT_PROCESS = 2
SEND_ANSWER = 3
CLOSED = 4
def locked(func):
"""Decorator which locks self.lock."""
def nested(self, *args, **kwargs):
self.lock.acquire()
try:
return func(self, *args, **kwargs)
finally:
self.lock.release()
return nested
def socket_exception(func):
"""Decorator close object on socket.error."""
def read(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except socket.error:
self.close()
return read
class Connection(object):
"""Basic class is represented connection.
It can be in state:
WAIT_LEN --- connection is reading request len.
WAIT_MESSAGE --- connection is reading request.
WAIT_PROCESS --- connection has just read whole request and
waits for call ready routine.
SEND_ANSWER --- connection is sending answer string (including length
of answer).
CLOSED --- socket was closed and connection should be deleted.
"""
def __init__(self, new_socket, wake_up):
self.socket = new_socket
self.socket.setblocking(False)
self.status = WAIT_LEN
self.len = 0
self.message = b''
self.lock = threading.Lock()
self.wake_up = wake_up
def _read_len(self):
"""Reads length of request.
It's a safer alternative to self.socket.recv(4)
"""
read = self.socket.recv(4 - len(self.message))
if len(read) == 0:
# if we read 0 bytes and self.message is empty, then
# the client closed the connection
if len(self.message) != 0:
logger.error("can't read frame size from socket")
self.close()
return
self.message += read
if len(self.message) == 4:
self.len, = struct.unpack('!i', self.message)
if self.len < 0:
logger.error("negative frame size, it seems client "
"doesn't use FramedTransport")
self.close()
elif self.len == 0:
logger.error("empty frame, it's really strange")
self.close()
else:
self.message = b''
self.status = WAIT_MESSAGE
@socket_exception
def read(self):
"""Reads data from stream and switch state."""
assert self.status in (WAIT_LEN, WAIT_MESSAGE)
if self.status == WAIT_LEN:
self._read_len()
# go back to the main loop here for simplicity instead of
# falling through, even though there is a good chance that
# the message is already available
elif self.status == WAIT_MESSAGE:
read = self.socket.recv(self.len - len(self.message))
if len(read) == 0:
logger.error("can't read frame from socket (get %d of "
"%d bytes)" % (len(self.message), self.len))
self.close()
return
self.message += read
if len(self.message) == self.len:
self.status = WAIT_PROCESS
@socket_exception
def write(self):
"""Writes data from socket and switch state."""
assert self.status == SEND_ANSWER
sent = self.socket.send(self.message)
if sent == len(self.message):
self.status = WAIT_LEN
self.message = b''
self.len = 0
else:
self.message = self.message[sent:]
@locked
def ready(self, all_ok, message):
"""Callback function for switching state and waking up main thread.
This function is the only function witch can be called asynchronous.
The ready can switch Connection to three states:
WAIT_LEN if request was oneway.
SEND_ANSWER if request was processed in normal way.
CLOSED if request throws unexpected exception.
The one wakes up main thread.
"""
assert self.status == WAIT_PROCESS
if not all_ok:
self.close()
self.wake_up()
return
self.len = 0
if len(message) == 0:
# it was a oneway request, do not write answer
self.message = b''
self.status = WAIT_LEN
else:
self.message = struct.pack('!i', len(message)) + message
self.status = SEND_ANSWER
self.wake_up()
@locked
def is_writeable(self):
"""Return True if connection should be added to write list of select"""
return self.status == SEND_ANSWER
# it's not necessary, but...
@locked
def is_readable(self):
"""Return True if connection should be added to read list of select"""
return self.status in (WAIT_LEN, WAIT_MESSAGE)
@locked
def is_closed(self):
"""Returns True if connection is closed."""
return self.status == CLOSED
def fileno(self):
"""Returns the file descriptor of the associated socket."""
return self.socket.fileno()
def close(self):
"""Closes connection"""
self.status = CLOSED
self.socket.close()
class TNonblockingServer(object):
"""Non-blocking server."""
def __init__(self,
processor,
lsocket,
inputProtocolFactory=None,
outputProtocolFactory=None,
threads=10):
self.processor = processor
self.socket = lsocket
self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
self.out_protocol = outputProtocolFactory or self.in_protocol
self.threads = int(threads)
self.clients = {}
self.tasks = queue.Queue()
self._read, self._write = socket.socketpair()
self.prepared = False
self._stop = False
def setNumThreads(self, num):
"""Set the number of worker threads that should be created."""
# implement ThreadPool interface
assert not self.prepared, "Can't change number of threads after start"
self.threads = num
def prepare(self):
"""Prepares server for serve requests."""
if self.prepared:
return
self.socket.listen()
for _ in range(self.threads):
thread = Worker(self.tasks)
thread.setDaemon(True)
thread.start()
self.prepared = True
def wake_up(self):
"""Wake up main thread.
The server usually waits in select call in we should terminate one.
The simplest way is using socketpair.
Select always wait to read from the first socket of socketpair.
In this case, we can just write anything to the second socket from
socketpair.
"""
self._write.send(b'1')
def stop(self):
"""Stop the server.
This method causes the serve() method to return. stop() may be invoked
from within your handler, or from another thread.
After stop() is called, serve() will return but the server will still
be listening on the socket. serve() may then be called again to resume
processing requests. Alternatively, close() may be called after
serve() returns to close the server socket and shutdown all worker
threads.
"""
self._stop = True
self.wake_up()
def _select(self):
"""Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()]
writable = []
for i, connection in list(self.clients.items()):
if connection.is_readable():
readable.append(connection.fileno())
if connection.is_writeable():
writable.append(connection.fileno())
if connection.is_closed():
del self.clients[i]
return select.select(readable, writable, readable)
def handle(self):
"""Handle requests.
WARNING! You must call prepare() BEFORE calling handle()
"""
assert self.prepared, "You have to call prepare before handle"
rset, wset, xset = self._select()
for readable in rset:
if readable == self._read.fileno():
# don't care i just need to clean readable flag
self._read.recv(1024)
elif readable == self.socket.handle.fileno():
client = self.socket.accept().handle
self.clients[client.fileno()] = Connection(client,
self.wake_up)
else:
connection = self.clients[readable]
connection.read()
if connection.status == WAIT_PROCESS:
itransport = TTransport.TMemoryBuffer(connection.message)
otransport = TTransport.TMemoryBuffer()
iprot = self.in_protocol.getProtocol(itransport)
oprot = self.out_protocol.getProtocol(otransport)
self.tasks.put([self.processor, iprot, oprot,
otransport, connection.ready])
for writeable in wset:
self.clients[writeable].write()
for oob in xset:
self.clients[oob].close()
del self.clients[oob]
def close(self):
"""Closes the server."""
for _ in range(self.threads):
self.tasks.put([None, None, None, None, None])
self.socket.close()
self.prepared = False
def serve(self):
"""Serve requests.
Serve requests forever, or until stop() is called.
"""
self._stop = False
self.prepare()
while not self._stop:
self.handle()

View file

@ -0,0 +1,123 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
from multiprocessing import Process, Value, Condition
from .TServer import TServer
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
class TProcessPoolServer(TServer):
"""Server with a fixed size pool of worker subprocesses to service requests
Note that if you need shared state between the handlers - it's up to you!
Written by Dvir Volk, doat.com
"""
def __init__(self, *args):
TServer.__init__(self, *args)
self.numWorkers = 10
self.workers = []
self.isRunning = Value('b', False)
self.stopCondition = Condition()
self.postForkCallback = None
def setPostForkCallback(self, callback):
if not callable(callback):
raise TypeError("This is not a callback!")
self.postForkCallback = callback
def setNumWorkers(self, num):
"""Set the number of worker threads that should be created"""
self.numWorkers = num
def workerProcess(self):
"""Loop getting clients from the shared queue and process them"""
if self.postForkCallback:
self.postForkCallback()
while self.isRunning.value:
try:
client = self.serverTransport.accept()
if not client:
continue
self.serveClient(client)
except (KeyboardInterrupt, SystemExit):
return 0
except Exception as x:
logger.exception(x)
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
otrans.close()
def serve(self):
"""Start workers and put into queue"""
# this is a shared state that can tell the workers to exit when False
self.isRunning.value = True
# first bind and listen to the port
self.serverTransport.listen()
# fork the children
for i in range(self.numWorkers):
try:
w = Process(target=self.workerProcess)
w.daemon = True
w.start()
self.workers.append(w)
except Exception as x:
logger.exception(x)
# wait until the condition is set by stop()
while True:
self.stopCondition.acquire()
try:
self.stopCondition.wait()
break
except (SystemExit, KeyboardInterrupt):
break
except Exception as x:
logger.exception(x)
self.isRunning.value = False
def stop(self):
self.isRunning.value = False
self.stopCondition.acquire()
self.stopCondition.notify()
self.stopCondition.release()

View file

@ -0,0 +1,276 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from six.moves import queue
import logging
import os
import threading
from thrift.protocol import TBinaryProtocol
from thrift.transport import TTransport
logger = logging.getLogger(__name__)
class TServer(object):
"""Base interface for a server, which must have a serve() method.
Three constructors for all servers:
1) (processor, serverTransport)
2) (processor, serverTransport, transportFactory, protocolFactory)
3) (processor, serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory)
"""
def __init__(self, *args):
if (len(args) == 2):
self.__initArgs__(args[0], args[1],
TTransport.TTransportFactoryBase(),
TTransport.TTransportFactoryBase(),
TBinaryProtocol.TBinaryProtocolFactory(),
TBinaryProtocol.TBinaryProtocolFactory())
elif (len(args) == 4):
self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
elif (len(args) == 6):
self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5])
def __initArgs__(self, processor, serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory):
self.processor = processor
self.serverTransport = serverTransport
self.inputTransportFactory = inputTransportFactory
self.outputTransportFactory = outputTransportFactory
self.inputProtocolFactory = inputProtocolFactory
self.outputProtocolFactory = outputProtocolFactory
def serve(self):
pass
class TSimpleServer(TServer):
"""Simple single-threaded server that just pumps around one transport."""
def __init__(self, *args):
TServer.__init__(self, *args)
def serve(self):
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
if not client:
continue
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
otrans.close()
class TThreadedServer(TServer):
"""Threaded server that spawns a new thread per each connection."""
def __init__(self, *args, **kwargs):
TServer.__init__(self, *args)
self.daemon = kwargs.get("daemon", False)
def serve(self):
self.serverTransport.listen()
while True:
try:
client = self.serverTransport.accept()
if not client:
continue
t = threading.Thread(target=self.handle, args=(client,))
t.setDaemon(self.daemon)
t.start()
except KeyboardInterrupt:
raise
except Exception as x:
logger.exception(x)
def handle(self, client):
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
otrans.close()
class TThreadPoolServer(TServer):
"""Server with a fixed size pool of threads which service requests."""
def __init__(self, *args, **kwargs):
TServer.__init__(self, *args)
self.clients = queue.Queue()
self.threads = 10
self.daemon = kwargs.get("daemon", False)
def setNumThreads(self, num):
"""Set the number of worker threads that should be created"""
self.threads = num
def serveThread(self):
"""Loop around getting clients from the shared queue and process them."""
while True:
try:
client = self.clients.get()
self.serveClient(client)
except Exception as x:
logger.exception(x)
def serveClient(self, client):
"""Process input/output from a client for as long as possible"""
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
itrans.close()
otrans.close()
def serve(self):
"""Start a fixed number of worker threads and put client into a queue"""
for i in range(self.threads):
try:
t = threading.Thread(target=self.serveThread)
t.setDaemon(self.daemon)
t.start()
except Exception as x:
logger.exception(x)
# Pump the socket for clients
self.serverTransport.listen()
while True:
try:
client = self.serverTransport.accept()
if not client:
continue
self.clients.put(client)
except Exception as x:
logger.exception(x)
class TForkingServer(TServer):
"""A Thrift server that forks a new process for each request
This is more scalable than the threaded server as it does not cause
GIL contention.
Note that this has different semantics from the threading server.
Specifically, updates to shared variables will no longer be shared.
It will also not work on windows.
This code is heavily inspired by SocketServer.ForkingMixIn in the
Python stdlib.
"""
def __init__(self, *args):
TServer.__init__(self, *args)
self.children = []
def serve(self):
def try_close(file):
try:
file.close()
except IOError as e:
logger.warning(e, exc_info=True)
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
if not client:
continue
try:
pid = os.fork()
if pid: # parent
# add before collect, otherwise you race w/ waitpid
self.children.append(pid)
self.collect_children()
# Parent must close socket or the connection may not get
# closed promptly
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
try_close(itrans)
try_close(otrans)
else:
itrans = self.inputTransportFactory.getTransport(client)
otrans = self.outputTransportFactory.getTransport(client)
iprot = self.inputProtocolFactory.getProtocol(itrans)
oprot = self.outputProtocolFactory.getProtocol(otrans)
ecode = 0
try:
try:
while True:
self.processor.process(iprot, oprot)
except TTransport.TTransportException:
pass
except Exception as e:
logger.exception(e)
ecode = 1
finally:
try_close(itrans)
try_close(otrans)
os._exit(ecode)
except TTransport.TTransportException:
pass
except Exception as x:
logger.exception(x)
def collect_children(self):
while self.children:
try:
pid, status = os.waitpid(0, os.WNOHANG)
except os.error:
pid = None
if pid:
self.children.remove(pid)
else:
break

View file

@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['TServer', 'TNonblockingServer']

View file

@ -0,0 +1,194 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from io import BytesIO
import os
import socket
import sys
import warnings
import base64
from six.moves import urllib
from six.moves import http_client
from .TTransport import TTransportBase
import six
class THttpClient(TTransportBase):
"""Http implementation of TTransport base."""
def __init__(self, uri_or_host, port=None, path=None):
"""THttpClient supports two different types constructor parameters.
THttpClient(host, port, path) - deprecated
THttpClient(uri)
Only the second supports https.
"""
if port is not None:
warnings.warn(
"Please use the THttpClient('http://host:port/path') syntax",
DeprecationWarning,
stacklevel=2)
self.host = uri_or_host
self.port = port
assert path
self.path = path
self.scheme = 'http'
else:
parsed = urllib.parse.urlparse(uri_or_host)
self.scheme = parsed.scheme
assert self.scheme in ('http', 'https')
if self.scheme == 'http':
self.port = parsed.port or http_client.HTTP_PORT
elif self.scheme == 'https':
self.port = parsed.port or http_client.HTTPS_PORT
self.host = parsed.hostname
self.path = parsed.path
if parsed.query:
self.path += '?%s' % parsed.query
try:
proxy = urllib.request.getproxies()[self.scheme]
except KeyError:
proxy = None
else:
if urllib.request.proxy_bypass(self.host):
proxy = None
if proxy:
parsed = urllib.parse.urlparse(proxy)
self.realhost = self.host
self.realport = self.port
self.host = parsed.hostname
self.port = parsed.port
self.proxy_auth = self.basic_proxy_auth_header(parsed)
else:
self.realhost = self.realport = self.proxy_auth = None
self.__wbuf = BytesIO()
self.__http = None
self.__http_response = None
self.__timeout = None
self.__custom_headers = None
@staticmethod
def basic_proxy_auth_header(proxy):
if proxy is None or not proxy.username:
return None
ap = "%s:%s" % (urllib.parse.unquote(proxy.username),
urllib.parse.unquote(proxy.password))
cr = base64.b64encode(ap).strip()
return "Basic " + cr
def using_proxy(self):
return self.realhost is not None
def open(self):
if self.scheme == 'http':
self.__http = http_client.HTTPConnection(self.host, self.port)
elif self.scheme == 'https':
self.__http = http_client.HTTPSConnection(self.host, self.port)
if self.using_proxy():
self.__http.set_tunnel(self.realhost, self.realport,
{"Proxy-Authorization": self.proxy_auth})
def close(self):
self.__http.close()
self.__http = None
self.__http_response = None
def isOpen(self):
return self.__http is not None
def setTimeout(self, ms):
if not hasattr(socket, 'getdefaulttimeout'):
raise NotImplementedError
if ms is None:
self.__timeout = None
else:
self.__timeout = ms / 1000.0
def setCustomHeaders(self, headers):
self.__custom_headers = headers
def read(self, sz):
return self.__http_response.read(sz)
def write(self, buf):
self.__wbuf.write(buf)
def __withTimeout(f):
def _f(*args, **kwargs):
orig_timeout = socket.getdefaulttimeout()
socket.setdefaulttimeout(args[0].__timeout)
try:
result = f(*args, **kwargs)
finally:
socket.setdefaulttimeout(orig_timeout)
return result
return _f
def flush(self):
if self.isOpen():
self.close()
self.open()
# Pull data out of buffer
data = self.__wbuf.getvalue()
self.__wbuf = BytesIO()
# HTTP request
if self.using_proxy() and self.scheme == "http":
# need full URL of real host for HTTP proxy here (HTTPS uses CONNECT tunnel)
self.__http.putrequest('POST', "http://%s:%s%s" %
(self.realhost, self.realport, self.path))
else:
self.__http.putrequest('POST', self.path)
# Write headers
self.__http.putheader('Content-Type', 'application/x-thrift')
self.__http.putheader('Content-Length', str(len(data)))
if self.using_proxy() and self.scheme == "http" and self.proxy_auth is not None:
self.__http.putheader("Proxy-Authorization", self.proxy_auth)
if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
user_agent = 'Python/THttpClient'
script = os.path.basename(sys.argv[0])
if script:
user_agent = '%s (%s)' % (user_agent, urllib.parse.quote(script))
self.__http.putheader('User-Agent', user_agent)
if self.__custom_headers:
for key, val in six.iteritems(self.__custom_headers):
self.__http.putheader(key, val)
self.__http.endheaders()
# Write payload
self.__http.send(data)
# Get reply to flush the request
self.__http_response = self.__http.getresponse()
self.code = self.__http_response.status
self.message = self.__http_response.reason
self.headers = self.__http_response.msg
# Decorate if we know how to timeout
if hasattr(socket, 'getdefaulttimeout'):
flush = __withTimeout(flush)

View file

@ -0,0 +1,396 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import os
import socket
import ssl
import sys
import warnings
from .sslcompat import _match_hostname, _match_has_ipaddress
from thrift.transport import TSocket
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
warnings.filterwarnings(
'default', category=DeprecationWarning, module=__name__)
class TSSLBase(object):
# SSLContext is not available for Python < 2.7.9
_has_ssl_context = sys.hexversion >= 0x020709F0
# ciphers argument is not available for Python < 2.7.0
_has_ciphers = sys.hexversion >= 0x020700F0
# For pythoon >= 2.7.9, use latest TLS that both client and server
# supports.
# SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
# For pythoon < 2.7.9, use TLS 1.0 since TLSv1_X nor OP_NO_SSLvX is
# unavailable.
_default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
ssl.PROTOCOL_TLSv1
def _init_context(self, ssl_version):
if self._has_ssl_context:
self._context = ssl.SSLContext(ssl_version)
if self._context.protocol == ssl.PROTOCOL_SSLv23:
self._context.options |= ssl.OP_NO_SSLv2
self._context.options |= ssl.OP_NO_SSLv3
else:
self._context = None
self._ssl_version = ssl_version
@property
def _should_verify(self):
if self._has_ssl_context:
return self._context.verify_mode != ssl.CERT_NONE
else:
return self.cert_reqs != ssl.CERT_NONE
@property
def ssl_version(self):
if self._has_ssl_context:
return self.ssl_context.protocol
else:
return self._ssl_version
@property
def ssl_context(self):
return self._context
SSL_VERSION = _default_protocol
"""
Default SSL version.
For backword compatibility, it can be modified.
Use __init__ keywoard argument "ssl_version" instead.
"""
def _deprecated_arg(self, args, kwargs, pos, key):
if len(args) <= pos:
return
real_pos = pos + 3
warnings.warn(
'%dth positional argument is deprecated.'
'please use keyward argument insteand.'
% real_pos, DeprecationWarning, stacklevel=3)
if key in kwargs:
raise TypeError(
'Duplicate argument: %dth argument and %s keyward argument.'
% (real_pos, key))
kwargs[key] = args[pos]
def _unix_socket_arg(self, host, port, args, kwargs):
key = 'unix_socket'
if host is None and port is None and len(args) == 1 and key not in kwargs:
kwargs[key] = args[0]
return True
return False
def __getattr__(self, key):
if key == 'SSL_VERSION':
warnings.warn(
'SSL_VERSION is deprecated.'
'please use ssl_version attribute instead.',
DeprecationWarning, stacklevel=2)
return self.ssl_version
def __init__(self, server_side, host, ssl_opts):
self._server_side = server_side
if TSSLBase.SSL_VERSION != self._default_protocol:
warnings.warn(
'SSL_VERSION is deprecated.'
'please use ssl_version keyward argument instead.',
DeprecationWarning, stacklevel=2)
self._context = ssl_opts.pop('ssl_context', None)
self._server_hostname = None
if not self._server_side:
self._server_hostname = ssl_opts.pop('server_hostname', host)
if self._context:
self._custom_context = True
if ssl_opts:
raise ValueError(
'Incompatible arguments: ssl_context and %s'
% ' '.join(ssl_opts.keys()))
if not self._has_ssl_context:
raise ValueError(
'ssl_context is not available for this version of Python')
else:
self._custom_context = False
ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
self._init_context(ssl_version)
self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
self.ca_certs = ssl_opts.pop('ca_certs', None)
self.keyfile = ssl_opts.pop('keyfile', None)
self.certfile = ssl_opts.pop('certfile', None)
self.ciphers = ssl_opts.pop('ciphers', None)
if ssl_opts:
raise ValueError(
'Unknown keyword arguments: ', ' '.join(ssl_opts.keys()))
if self._should_verify:
if not self.ca_certs:
raise ValueError(
'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
if not os.access(self.ca_certs, os.R_OK):
raise IOError('Certificate Authority ca_certs file "%s" '
'is not readable, cannot validate SSL '
'certificates.' % (self.ca_certs))
@property
def certfile(self):
return self._certfile
@certfile.setter
def certfile(self, certfile):
if self._server_side and not certfile:
raise ValueError('certfile is needed for server-side')
if certfile and not os.access(certfile, os.R_OK):
raise IOError('No such certfile found: %s' % (certfile))
self._certfile = certfile
def _wrap_socket(self, sock):
if self._has_ssl_context:
if not self._custom_context:
self.ssl_context.verify_mode = self.cert_reqs
if self.certfile:
self.ssl_context.load_cert_chain(self.certfile,
self.keyfile)
if self.ciphers:
self.ssl_context.set_ciphers(self.ciphers)
if self.ca_certs:
self.ssl_context.load_verify_locations(self.ca_certs)
return self.ssl_context.wrap_socket(
sock, server_side=self._server_side,
server_hostname=self._server_hostname)
else:
ssl_opts = {
'ssl_version': self._ssl_version,
'server_side': self._server_side,
'ca_certs': self.ca_certs,
'keyfile': self.keyfile,
'certfile': self.certfile,
'cert_reqs': self.cert_reqs,
}
if self.ciphers:
if self._has_ciphers:
ssl_opts['ciphers'] = self.ciphers
else:
logger.warning(
'ciphers is specified but ignored due to old Python version')
return ssl.wrap_socket(sock, **ssl_opts)
class TSSLSocket(TSocket.TSocket, TSSLBase):
"""
SSL implementation of TSocket
This class creates outbound sockets wrapped using the
python standard ssl module for encrypted connections.
"""
# New signature
# def __init__(self, host='localhost', port=9090, unix_socket=None,
# **ssl_args):
# Deprecated signature
# def __init__(self, host='localhost', port=9090, validate=True,
# ca_certs=None, keyfile=None, certfile=None,
# unix_socket=None, ciphers=None):
def __init__(self, host='localhost', port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
``ssl_version``, ``ca_certs``,
``ciphers`` (Python 2.7.0 or later),
``server_hostname`` (Python 2.7.9 or later)
Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
Alternative keyword arguments: (Python 2.7.9 or later)
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
Common keyword argument:
``validate_callback`` (cert, hostname) -> None:
Called after SSL handshake. Can raise when hostname does not
match the cert.
"""
self.is_valid = False
self.peercert = None
if args:
if len(args) > 6:
raise TypeError('Too many positional argument')
if not self._unix_socket_arg(host, port, args, kwargs):
self._deprecated_arg(args, kwargs, 0, 'validate')
self._deprecated_arg(args, kwargs, 1, 'ca_certs')
self._deprecated_arg(args, kwargs, 2, 'keyfile')
self._deprecated_arg(args, kwargs, 3, 'certfile')
self._deprecated_arg(args, kwargs, 4, 'unix_socket')
self._deprecated_arg(args, kwargs, 5, 'ciphers')
validate = kwargs.pop('validate', None)
if validate is not None:
cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
warnings.warn(
'validate is deprecated. please use cert_reqs=ssl.%s instead'
% cert_reqs_name,
DeprecationWarning, stacklevel=2)
if 'cert_reqs' in kwargs:
raise TypeError('Cannot specify both validate and cert_reqs')
kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE
unix_socket = kwargs.pop('unix_socket', None)
self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, False, host, kwargs)
TSocket.TSocket.__init__(self, host, port, unix_socket)
@property
def validate(self):
warnings.warn('validate is deprecated. please use cert_reqs instead',
DeprecationWarning, stacklevel=2)
return self.cert_reqs != ssl.CERT_NONE
@validate.setter
def validate(self, value):
warnings.warn('validate is deprecated. please use cert_reqs instead',
DeprecationWarning, stacklevel=2)
self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE
def _do_open(self, family, socktype):
plain_sock = socket.socket(family, socktype)
try:
return self._wrap_socket(plain_sock)
except Exception:
plain_sock.close()
msg = 'failed to initialize SSL'
logger.exception(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
def open(self):
super(TSSLSocket, self).open()
if self._should_verify:
self.peercert = self.handle.getpeercert()
try:
self._validate_callback(self.peercert, self._server_hostname)
self.is_valid = True
except TTransportException:
raise
except Exception as ex:
raise TTransportException(TTransportException.UNKNOWN, str(ex))
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
"""SSL implementation of TServerSocket
This uses the ssl module's wrap_socket() method to provide SSL
negotiated encryption.
"""
# New signature
# def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
# Deprecated signature
# def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
def __init__(self, host=None, port=9090, *args, **kwargs):
"""Positional arguments: ``host``, ``port``, ``unix_socket``
Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
See ssl.wrap_socket documentation.
Alternative keyword arguments: (Python 2.7.9 or later)
``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
``server_hostname``: Passed to SSLContext.wrap_socket
Common keyword argument:
``validate_callback`` (cert, hostname) -> None:
Called after SSL handshake. Can raise when hostname does not
match the cert.
"""
if args:
if len(args) > 3:
raise TypeError('Too many positional argument')
if not self._unix_socket_arg(host, port, args, kwargs):
self._deprecated_arg(args, kwargs, 0, 'certfile')
self._deprecated_arg(args, kwargs, 1, 'unix_socket')
self._deprecated_arg(args, kwargs, 2, 'ciphers')
if 'ssl_context' not in kwargs:
# Preserve existing behaviors for default values
if 'cert_reqs' not in kwargs:
kwargs['cert_reqs'] = ssl.CERT_NONE
if'certfile' not in kwargs:
kwargs['certfile'] = 'cert.pem'
unix_socket = kwargs.pop('unix_socket', None)
self._validate_callback = \
kwargs.pop('validate_callback', _match_hostname)
TSSLBase.__init__(self, True, None, kwargs)
TSocket.TServerSocket.__init__(self, host, port, unix_socket)
if self._should_verify and not _match_has_ipaddress:
raise ValueError('Need ipaddress and backports.ssl_match_hostname '
'module to verify client certificate')
def setCertfile(self, certfile):
"""Set or change the server certificate file used to wrap new
connections.
@param certfile: The filename of the server certificate,
i.e. '/etc/certs/server.pem'
@type certfile: str
Raises an IOError exception if the certfile is not present or unreadable.
"""
warnings.warn(
'setCertfile is deprecated. please use certfile property instead.',
DeprecationWarning, stacklevel=2)
self.certfile = certfile
def accept(self):
plain_client, addr = self.handle.accept()
try:
client = self._wrap_socket(plain_client)
except ssl.SSLError:
logger.exception('Error while accepting from %s', addr)
# failed handshake/ssl wrap, close socket to client
plain_client.close()
# raise
# We can't raise the exception, because it kills most TServer derived
# serve() methods.
# Instead, return None, and let the TServer instance deal with it in
# other exception handling. (but TSimpleServer dies anyway)
return None
if self._should_verify:
client.peercert = client.getpeercert()
try:
self._validate_callback(client.peercert, addr[0])
client.is_valid = True
except Exception:
logger.warn('Failed to validate client certificate address: %s',
addr[0], exc_info=True)
client.close()
plain_client.close()
return None
result = TSocket.TSocket()
result.handle = client
return result

View file

@ -0,0 +1,192 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import errno
import logging
import os
import socket
import sys
from .TTransport import TTransportBase, TTransportException, TServerTransportBase
logger = logging.getLogger(__name__)
class TSocketBase(TTransportBase):
def _resolveAddr(self):
if self._unix_socket is not None:
return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
self._unix_socket)]
else:
return socket.getaddrinfo(self.host,
self.port,
self._socket_family,
socket.SOCK_STREAM,
0,
socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
def close(self):
if self.handle:
self.handle.close()
self.handle = None
class TSocket(TSocketBase):
"""Socket implementation of TTransport base."""
def __init__(self, host='localhost', port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
"""Initialize a TSocket
@param host(str) The host to connect to.
@param port(int) The (TCP) port to connect to.
@param unix_socket(str) The filename of a unix socket to connect to.
(host and port will be ignored.)
@param socket_family(int) The socket family to use with this socket.
"""
self.host = host
self.port = port
self.handle = None
self._unix_socket = unix_socket
self._timeout = None
self._socket_family = socket_family
def setHandle(self, h):
self.handle = h
def isOpen(self):
return self.handle is not None
def setTimeout(self, ms):
if ms is None:
self._timeout = None
else:
self._timeout = ms / 1000.0
if self.handle is not None:
self.handle.settimeout(self._timeout)
def _do_open(self, family, socktype):
return socket.socket(family, socktype)
@property
def _address(self):
return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)
def open(self):
if self.handle:
raise TTransportException(TTransportException.ALREADY_OPEN)
try:
addrs = self._resolveAddr()
except socket.gaierror:
msg = 'failed to resolve sockaddr for ' + str(self._address)
logger.exception(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
for family, socktype, _, _, sockaddr in addrs:
handle = self._do_open(family, socktype)
handle.settimeout(self._timeout)
try:
handle.connect(sockaddr)
self.handle = handle
return
except socket.error:
handle.close()
logger.info('Could not connect to %s', sockaddr, exc_info=True)
msg = 'Could not connect to any of %s' % list(map(lambda a: a[4],
addrs))
logger.error(msg)
raise TTransportException(TTransportException.NOT_OPEN, msg)
def read(self, sz):
try:
buff = self.handle.recv(sz)
except socket.error as e:
if (e.args[0] == errno.ECONNRESET and
(sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
# freebsd and Mach don't follow POSIX semantic of recv
# and fail with ECONNRESET if peer performed shutdown.
# See corresponding comment and code in TSocket::read()
# in lib/cpp/src/transport/TSocket.cpp.
self.close()
# Trigger the check to raise the END_OF_FILE exception below.
buff = ''
else:
raise
if len(buff) == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket read 0 bytes')
return buff
def write(self, buff):
if not self.handle:
raise TTransportException(type=TTransportException.NOT_OPEN,
message='Transport not open')
sent = 0
have = len(buff)
while sent < have:
plus = self.handle.send(buff)
if plus == 0:
raise TTransportException(type=TTransportException.END_OF_FILE,
message='TSocket sent 0 bytes')
sent += plus
buff = buff[plus:]
def flush(self):
pass
class TServerSocket(TSocketBase, TServerTransportBase):
"""Socket implementation of TServerTransport base."""
def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
self.host = host
self.port = port
self._unix_socket = unix_socket
self._socket_family = socket_family
self.handle = None
def listen(self):
res0 = self._resolveAddr()
socket_family = self._socket_family == socket.AF_UNSPEC and socket.AF_INET6 or self._socket_family
for res in res0:
if res[0] is socket_family or res is res0[-1]:
break
# We need remove the old unix socket if the file exists and
# nobody is listening on it.
if self._unix_socket:
tmp = socket.socket(res[0], res[1])
try:
tmp.connect(res[4])
except socket.error as err:
eno, message = err.args
if eno == errno.ECONNREFUSED:
os.unlink(res[4])
self.handle = socket.socket(res[0], res[1])
self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(self.handle, 'settimeout'):
self.handle.settimeout(None)
self.handle.bind(res[4])
self.handle.listen(128)
def accept(self):
client, addr = self.handle.accept()
result = TSocket()
result.setHandle(client)
return result

View file

@ -0,0 +1,452 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from struct import pack, unpack
from thrift.Thrift import TException
from ..compat import BufferIO
class TTransportException(TException):
"""Custom Transport Exception class"""
UNKNOWN = 0
NOT_OPEN = 1
ALREADY_OPEN = 2
TIMED_OUT = 3
END_OF_FILE = 4
NEGATIVE_SIZE = 5
SIZE_LIMIT = 6
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
self.type = type
class TTransportBase(object):
"""Base class for Thrift transport layer."""
def isOpen(self):
pass
def open(self):
pass
def close(self):
pass
def read(self, sz):
pass
def readAll(self, sz):
buff = b''
have = 0
while (have < sz):
chunk = self.read(sz - have)
have += len(chunk)
buff += chunk
if len(chunk) == 0:
raise EOFError()
return buff
def write(self, buf):
pass
def flush(self):
pass
# This class should be thought of as an interface.
class CReadableTransport(object):
"""base class for transports that are readable from C"""
# TODO(dreiss): Think about changing this interface to allow us to use
# a (Python, not c) StringIO instead, because it allows
# you to write after reading.
# NOTE: This is a classic class, so properties will NOT work
# correctly for setting.
@property
def cstringio_buf(self):
"""A cStringIO buffer that contains the current chunk we are reading."""
pass
def cstringio_refill(self, partialread, reqlen):
"""Refills cstringio_buf.
Returns the currently used buffer (which can but need not be the same as
the old cstringio_buf). partialread is what the C code has read from the
buffer, and should be inserted into the buffer before any more reads. The
return value must be a new, not borrowed reference. Something along the
lines of self._buf should be fine.
If reqlen bytes can't be read, throw EOFError.
"""
pass
class TServerTransportBase(object):
"""Base class for Thrift server transports."""
def listen(self):
pass
def accept(self):
pass
def close(self):
pass
class TTransportFactoryBase(object):
"""Base class for a Transport Factory"""
def getTransport(self, trans):
return trans
class TBufferedTransportFactory(object):
"""Factory transport that builds buffered transports"""
def getTransport(self, trans):
buffered = TBufferedTransport(trans)
return buffered
class TBufferedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and buffers its I/O.
The implementation uses a (configurable) fixed-size read buffer
but buffers all writes until a flush is performed.
"""
DEFAULT_BUFFER = 4096
def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
self.__trans = trans
self.__wbuf = BufferIO()
# Pass string argument to initialize read buffer as cStringIO.InputType
self.__rbuf = BufferIO(b'')
self.__rbuf_size = rbuf_size
def isOpen(self):
return self.__trans.isOpen()
def open(self):
return self.__trans.open()
def close(self):
return self.__trans.close()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
return self.__rbuf.read(sz)
def write(self, buf):
try:
self.__wbuf.write(buf)
except Exception as e:
# on exception reset wbuf so it doesn't contain a partial function call
self.__wbuf = BufferIO()
raise e
self.__wbuf.getvalue()
def flush(self):
out = self.__wbuf.getvalue()
# reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BufferIO()
self.__trans.write(out)
self.__trans.flush()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, partialread, reqlen):
retstring = partialread
if reqlen < self.__rbuf_size:
# try to make a read of as much as we can.
retstring += self.__trans.read(self.__rbuf_size)
# but make sure we do read reqlen bytes.
if len(retstring) < reqlen:
retstring += self.__trans.readAll(reqlen - len(retstring))
self.__rbuf = BufferIO(retstring)
return self.__rbuf
class TMemoryBuffer(TTransportBase, CReadableTransport):
"""Wraps a cBytesIO object as a TTransport.
NOTE: Unlike the C++ version of this class, you cannot write to it
then immediately read from it. If you want to read from a
TMemoryBuffer, you must either pass a string to the constructor.
TODO(dreiss): Make this work like the C++ version.
"""
def __init__(self, value=None):
"""value -- a value to read from for stringio
If value is set, this will be a transport for reading,
otherwise, it is for writing"""
if value is not None:
self._buffer = BufferIO(value)
else:
self._buffer = BufferIO()
def isOpen(self):
return not self._buffer.closed
def open(self):
pass
def close(self):
self._buffer.close()
def read(self, sz):
return self._buffer.read(sz)
def write(self, buf):
self._buffer.write(buf)
def flush(self):
pass
def getvalue(self):
return self._buffer.getvalue()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self._buffer
def cstringio_refill(self, partialread, reqlen):
# only one shot at reading...
raise EOFError()
class TFramedTransportFactory(object):
"""Factory transport that builds framed transports"""
def getTransport(self, trans):
framed = TFramedTransport(trans)
return framed
class TFramedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and frames its I/O when writing."""
def __init__(self, trans,):
self.__trans = trans
self.__rbuf = BufferIO(b'')
self.__wbuf = BufferIO()
def isOpen(self):
return self.__trans.isOpen()
def open(self):
return self.__trans.open()
def close(self):
return self.__trans.close()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self.readFrame()
return self.__rbuf.read(sz)
def readFrame(self):
buff = self.__trans.readAll(4)
sz, = unpack('!i', buff)
self.__rbuf = BufferIO(self.__trans.readAll(sz))
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
wout = self.__wbuf.getvalue()
wsz = len(wout)
# reset wbuf before write/flush to preserve state on underlying failure
self.__wbuf = BufferIO()
# N.B.: Doing this string concatenation is WAY cheaper than making
# two separate calls to the underlying socket object. Socket writes in
# Python turn out to be REALLY expensive, but it seems to do a pretty
# good job of managing string buffer operations without excessive copies
buf = pack("!i", wsz) + wout
self.__trans.write(buf)
self.__trans.flush()
# Implement the CReadableTransport interface.
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self.readFrame()
prefix += self.__rbuf.getvalue()
self.__rbuf = BufferIO(prefix)
return self.__rbuf
class TFileObjectTransport(TTransportBase):
"""Wraps a file-like object to make it work as a Thrift transport."""
def __init__(self, fileobj):
self.fileobj = fileobj
def isOpen(self):
return True
def close(self):
self.fileobj.close()
def read(self, sz):
return self.fileobj.read(sz)
def write(self, buf):
self.fileobj.write(buf)
def flush(self):
self.fileobj.flush()
class TSaslClientTransport(TTransportBase, CReadableTransport):
"""
SASL transport
"""
START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5
def __init__(self, transport, host, service, mechanism='GSSAPI',
**sasl_kwargs):
"""
transport: an underlying transport to use, typically just a TSocket
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
mechanism: the name of the preferred mechanism to use
All other kwargs will be passed to the puresasl.client.SASLClient
constructor.
"""
from puresasl.client import SASLClient
self.transport = transport
self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
self.__wbuf = BufferIO()
self.__rbuf = BufferIO(b'')
def open(self):
if not self.transport.isOpen():
self.transport.open()
self.send_sasl_msg(self.START, self.sasl.mechanism)
self.send_sasl_msg(self.OK, self.sasl.process())
while True:
status, challenge = self.recv_sasl_msg()
if status == self.OK:
self.send_sasl_msg(self.OK, self.sasl.process(challenge))
elif status == self.COMPLETE:
if not self.sasl.complete:
raise TTransportException(
TTransportException.NOT_OPEN,
"The server erroneously indicated "
"that SASL negotiation was complete")
else:
break
else:
raise TTransportException(
TTransportException.NOT_OPEN,
"Bad SASL negotiation status: %d (%s)"
% (status, challenge))
def send_sasl_msg(self, status, body):
header = pack(">BI", status, len(body))
self.transport.write(header + body)
self.transport.flush()
def recv_sasl_msg(self):
header = self.transport.readAll(5)
status, length = unpack(">BI", header)
if length > 0:
payload = self.transport.readAll(length)
else:
payload = ""
return status, payload
def write(self, data):
self.__wbuf.write(data)
def flush(self):
data = self.__wbuf.getvalue()
encoded = self.sasl.wrap(data)
self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
self.transport.flush()
self.__wbuf = BufferIO()
def read(self, sz):
ret = self.__rbuf.read(sz)
if len(ret) != 0:
return ret
self._read_frame()
return self.__rbuf.read(sz)
def _read_frame(self):
header = self.transport.readAll(4)
length, = unpack('!i', header)
encoded = self.transport.readAll(length)
self.__rbuf = BufferIO(self.sasl.unwrap(encoded))
def close(self):
self.sasl.dispose()
self.transport.close()
# based on TFramedTransport
@property
def cstringio_buf(self):
return self.__rbuf
def cstringio_refill(self, prefix, reqlen):
# self.__rbuf will already be empty here because fastbinary doesn't
# ask for a refill until the previous buffer is empty. Therefore,
# we can start reading new frames immediately.
while len(prefix) < reqlen:
self._read_frame()
prefix += self.__rbuf.getvalue()
self.__rbuf = BufferIO(prefix)
return self.__rbuf

View file

@ -0,0 +1,331 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from io import BytesIO
import struct
from zope.interface import implements, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
connectionDone
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic
from twisted.web import server, resource, http
from thrift.transport import TTransport
class TMessageSenderTransport(TTransport.TTransportBase):
def __init__(self):
self.__wbuf = BytesIO()
def write(self, buf):
self.__wbuf.write(buf)
def flush(self):
msg = self.__wbuf.getvalue()
self.__wbuf = BytesIO()
return self.sendMessage(msg)
def sendMessage(self, message):
raise NotImplementedError
class TCallbackTransport(TMessageSenderTransport):
def __init__(self, func):
TMessageSenderTransport.__init__(self)
self.func = func
def sendMessage(self, message):
return self.func(message)
class ThriftClientProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None):
self._client_class = client_class
self._iprot_factory = iprot_factory
if oprot_factory is None:
self._oprot_factory = iprot_factory
else:
self._oprot_factory = oprot_factory
self.recv_map = {}
self.started = defer.Deferred()
def dispatch(self, msg):
self.sendString(msg)
def connectionMade(self):
tmo = TCallbackTransport(self.dispatch)
self.client = self._client_class(tmo, self._oprot_factory)
self.started.callback(self.client)
def connectionLost(self, reason=connectionDone):
# the called errbacks can add items to our client's _reqs,
# so we need to use a tmp, and iterate until no more requests
# are added during errbacks
if self.client:
tex = TTransport.TTransportException(
type=TTransport.TTransportException.END_OF_FILE,
message='Connection closed (%s)' % reason)
while self.client._reqs:
_, v = self.client._reqs.popitem()
v.errback(tex)
del self.client._reqs
self.client = None
def stringReceived(self, frame):
tr = TTransport.TMemoryBuffer(frame)
iprot = self._iprot_factory.getProtocol(tr)
(fname, mtype, rseqid) = iprot.readMessageBegin()
try:
method = self.recv_map[fname]
except KeyError:
method = getattr(self.client, 'recv_' + fname)
self.recv_map[fname] = method
method(iprot, mtype, rseqid)
class ThriftSASLClientProtocol(ThriftClientProtocol):
START = 1
OK = 2
BAD = 3
ERROR = 4
COMPLETE = 5
MAX_LENGTH = 2 ** 31 - 1
def __init__(self, client_class, iprot_factory, oprot_factory=None,
host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
"""
host: the name of the server, from a SASL perspective
service: the name of the server's service, from a SASL perspective
mechanism: the name of the preferred mechanism to use
All other kwargs will be passed to the puresasl.client.SASLClient
constructor.
"""
from puresasl.client import SASLClient
self.SASLCLient = SASLClient
ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
self._sasl_negotiation_deferred = None
self._sasl_negotiation_status = None
self.client = None
if host is not None:
self.createSASLClient(host, service, mechanism, **sasl_kwargs)
def createSASLClient(self, host, service, mechanism, **kwargs):
self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
def dispatch(self, msg):
encoded = self.sasl.wrap(msg)
len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
ThriftClientProtocol.dispatch(self, len_and_encoded)
@defer.inlineCallbacks
def connectionMade(self):
self._sendSASLMessage(self.START, self.sasl.mechanism)
initial_message = yield deferToThread(self.sasl.process)
self._sendSASLMessage(self.OK, initial_message)
while True:
status, challenge = yield self._receiveSASLMessage()
if status == self.OK:
response = yield deferToThread(self.sasl.process, challenge)
self._sendSASLMessage(self.OK, response)
elif status == self.COMPLETE:
if not self.sasl.complete:
msg = "The server erroneously indicated that SASL " \
"negotiation was complete"
raise TTransport.TTransportException(msg, message=msg)
else:
break
else:
msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
raise TTransport.TTransportException(msg, message=msg)
self._sasl_negotiation_deferred = None
ThriftClientProtocol.connectionMade(self)
def _sendSASLMessage(self, status, body):
if body is None:
body = ""
header = struct.pack(">BI", status, len(body))
self.transport.write(header + body)
def _receiveSASLMessage(self):
self._sasl_negotiation_deferred = defer.Deferred()
self._sasl_negotiation_status = None
return self._sasl_negotiation_deferred
def connectionLost(self, reason=connectionDone):
if self.client:
ThriftClientProtocol.connectionLost(self, reason)
def dataReceived(self, data):
if self._sasl_negotiation_deferred:
# we got a sasl challenge in the format (status, length, challenge)
# save the status, let IntNStringReceiver piece the challenge data together
self._sasl_negotiation_status, = struct.unpack("B", data[0])
ThriftClientProtocol.dataReceived(self, data[1:])
else:
# normal frame, let IntNStringReceiver piece it together
ThriftClientProtocol.dataReceived(self, data)
def stringReceived(self, frame):
if self._sasl_negotiation_deferred:
# the frame is just a SASL challenge
response = (self._sasl_negotiation_status, frame)
self._sasl_negotiation_deferred.callback(response)
else:
# there's a second 4 byte length prefix inside the frame
decoded_frame = self.sasl.unwrap(frame[4:])
ThriftClientProtocol.stringReceived(self, decoded_frame)
class ThriftServerProtocol(basic.Int32StringReceiver):
MAX_LENGTH = 2 ** 31 - 1
def dispatch(self, msg):
self.sendString(msg)
def processError(self, error):
self.transport.loseConnection()
def processOk(self, _, tmo):
msg = tmo.getvalue()
if len(msg) > 0:
self.dispatch(msg)
def stringReceived(self, frame):
tmi = TTransport.TMemoryBuffer(frame)
tmo = TTransport.TMemoryBuffer()
iprot = self.factory.iprot_factory.getProtocol(tmi)
oprot = self.factory.oprot_factory.getProtocol(tmo)
d = self.factory.processor.process(iprot, oprot)
d.addCallbacks(self.processOk, self.processError,
callbackArgs=(tmo,))
class IThriftServerFactory(Interface):
processor = Attribute("Thrift processor")
iprot_factory = Attribute("Input protocol factory")
oprot_factory = Attribute("Output protocol factory")
class IThriftClientFactory(Interface):
client_class = Attribute("Thrift client class")
iprot_factory = Attribute("Input protocol factory")
oprot_factory = Attribute("Output protocol factory")
class ThriftServerFactory(ServerFactory):
implements(IThriftServerFactory)
protocol = ThriftServerProtocol
def __init__(self, processor, iprot_factory, oprot_factory=None):
self.processor = processor
self.iprot_factory = iprot_factory
if oprot_factory is None:
self.oprot_factory = iprot_factory
else:
self.oprot_factory = oprot_factory
class ThriftClientFactory(ClientFactory):
implements(IThriftClientFactory)
protocol = ThriftClientProtocol
def __init__(self, client_class, iprot_factory, oprot_factory=None):
self.client_class = client_class
self.iprot_factory = iprot_factory
if oprot_factory is None:
self.oprot_factory = iprot_factory
else:
self.oprot_factory = oprot_factory
def buildProtocol(self, addr):
p = self.protocol(self.client_class, self.iprot_factory,
self.oprot_factory)
p.factory = self
return p
class ThriftResource(resource.Resource):
allowedMethods = ('POST',)
def __init__(self, processor, inputProtocolFactory,
outputProtocolFactory=None):
resource.Resource.__init__(self)
self.inputProtocolFactory = inputProtocolFactory
if outputProtocolFactory is None:
self.outputProtocolFactory = inputProtocolFactory
else:
self.outputProtocolFactory = outputProtocolFactory
self.processor = processor
def getChild(self, path, request):
return self
def _cbProcess(self, _, request, tmo):
msg = tmo.getvalue()
request.setResponseCode(http.OK)
request.setHeader("content-type", "application/x-thrift")
request.write(msg)
request.finish()
def render_POST(self, request):
request.content.seek(0, 0)
data = request.content.read()
tmi = TTransport.TMemoryBuffer(data)
tmo = TTransport.TMemoryBuffer()
iprot = self.inputProtocolFactory.getProtocol(tmi)
oprot = self.outputProtocolFactory.getProtocol(tmo)
d = self.processor.process(iprot, oprot)
d.addCallback(self._cbProcess, request, tmo)
return server.NOT_DONE_YET

View file

@ -0,0 +1,248 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""TZlibTransport provides a compressed transport and transport factory
class, using the python standard library zlib module to implement
data compression.
"""
from __future__ import division
import zlib
from .TTransport import TTransportBase, CReadableTransport
from ..compat import BufferIO
class TZlibTransportFactory(object):
"""Factory transport that builds zlib compressed transports.
This factory caches the last single client/transport that it was passed
and returns the same TZlibTransport object that was created.
This caching means the TServer class will get the _same_ transport
object for both input and output transports from this factory.
(For non-threaded scenarios only, since the cache only holds one object)
The purpose of this caching is to allocate only one TZlibTransport where
only one is really needed (since it must have separate read/write buffers),
and makes the statistics from getCompSavings() and getCompRatio()
easier to understand.
"""
# class scoped cache of last transport given and zlibtransport returned
_last_trans = None
_last_z = None
def getTransport(self, trans, compresslevel=9):
"""Wrap a transport, trans, with the TZlibTransport
compressed transport class, returning a new
transport to the caller.
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Defaults to 9.
@type compresslevel: int
This method returns a TZlibTransport which wraps the
passed C{trans} TTransport derived instance.
"""
if trans == self._last_trans:
return self._last_z
ztrans = TZlibTransport(trans, compresslevel)
self._last_trans = trans
self._last_z = ztrans
return ztrans
class TZlibTransport(TTransportBase, CReadableTransport):
"""Class that wraps a transport with zlib, compressing writes
and decompresses reads, using the python standard
library zlib module.
"""
# Read buffer size for the python fastbinary C extension,
# the TBinaryProtocolAccelerated class.
DEFAULT_BUFFSIZE = 4096
def __init__(self, trans, compresslevel=9):
"""Create a new TZlibTransport, wrapping C{trans}, another
TTransport derived object.
@param trans: A thrift transport object, i.e. a TSocket() object.
@type trans: TTransport
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Default is 9.
@type compresslevel: int
"""
self.__trans = trans
self.compresslevel = compresslevel
self.__rbuf = BufferIO()
self.__wbuf = BufferIO()
self._init_zlib()
self._init_stats()
def _reinit_buffers(self):
"""Internal method to initialize/reset the internal StringIO objects
for read and write buffers.
"""
self.__rbuf = BufferIO()
self.__wbuf = BufferIO()
def _init_stats(self):
"""Internal method to reset the internal statistics counters
for compression ratios and bandwidth savings.
"""
self.bytes_in = 0
self.bytes_out = 0
self.bytes_in_comp = 0
self.bytes_out_comp = 0
def _init_zlib(self):
"""Internal method for setting up the zlib compression and
decompression objects.
"""
self._zcomp_read = zlib.decompressobj()
self._zcomp_write = zlib.compressobj(self.compresslevel)
def getCompRatio(self):
"""Get the current measured compression ratios (in,out) from
this transport.
Returns a tuple of:
(inbound_compression_ratio, outbound_compression_ratio)
The compression ratios are computed as:
compressed / uncompressed
E.g., data that compresses by 10x will have a ratio of: 0.10
and data that compresses to half of ts original size will
have a ratio of 0.5
None is returned if no bytes have yet been processed in
a particular direction.
"""
r_percent, w_percent = (None, None)
if self.bytes_in > 0:
r_percent = self.bytes_in_comp / self.bytes_in
if self.bytes_out > 0:
w_percent = self.bytes_out_comp / self.bytes_out
return (r_percent, w_percent)
def getCompSavings(self):
"""Get the current count of saved bytes due to data
compression.
Returns a tuple of:
(inbound_saved_bytes, outbound_saved_bytes)
Note: if compression is actually expanding your
data (only likely with very tiny thrift objects), then
the values returned will be negative.
"""
r_saved = self.bytes_in - self.bytes_in_comp
w_saved = self.bytes_out - self.bytes_out_comp
return (r_saved, w_saved)
def isOpen(self):
"""Return the underlying transport's open status"""
return self.__trans.isOpen()
def open(self):
"""Open the underlying transport"""
self._init_stats()
return self.__trans.open()
def listen(self):
"""Invoke the underlying transport's listen() method"""
self.__trans.listen()
def accept(self):
"""Accept connections on the underlying transport"""
return self.__trans.accept()
def close(self):
"""Close the underlying transport,"""
self._reinit_buffers()
self._init_zlib()
return self.__trans.close()
def read(self, sz):
"""Read up to sz bytes from the decompressed bytes buffer, and
read from the underlying transport if the decompression
buffer is empty.
"""
ret = self.__rbuf.read(sz)
if len(ret) > 0:
return ret
# keep reading from transport until something comes back
while True:
if self.readComp(sz):
break
ret = self.__rbuf.read(sz)
return ret
def readComp(self, sz):
"""Read compressed data from the underlying transport, then
decompress it and append it to the internal StringIO read buffer
"""
zbuf = self.__trans.read(sz)
zbuf = self._zcomp_read.unconsumed_tail + zbuf
buf = self._zcomp_read.decompress(zbuf)
self.bytes_in += len(zbuf)
self.bytes_in_comp += len(buf)
old = self.__rbuf.read()
self.__rbuf = BufferIO(old + buf)
if len(old) + len(buf) == 0:
return False
return True
def write(self, buf):
"""Write some bytes, putting them into the internal write
buffer for eventual compression.
"""
self.__wbuf.write(buf)
def flush(self):
"""Flush any queued up data in the write buffer and ensure the
compression buffer is flushed out to the underlying transport
"""
wout = self.__wbuf.getvalue()
if len(wout) > 0:
zbuf = self._zcomp_write.compress(wout)
self.bytes_out += len(wout)
self.bytes_out_comp += len(zbuf)
else:
zbuf = ''
ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH)
self.bytes_out_comp += len(ztail)
if (len(zbuf) + len(ztail)) > 0:
self.__wbuf = BufferIO()
self.__trans.write(zbuf + ztail)
self.__trans.flush()
@property
def cstringio_buf(self):
"""Implement the CReadableTransport interface"""
return self.__rbuf
def cstringio_refill(self, partialread, reqlen):
"""Implement the CReadableTransport interface for refill"""
retstring = partialread
if reqlen < self.DEFAULT_BUFFSIZE:
retstring += self.read(self.DEFAULT_BUFFSIZE)
while len(retstring) < reqlen:
retstring += self.read(reqlen - len(retstring))
self.__rbuf = BufferIO(retstring)
return self.__rbuf

View file

@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport']

View file

@ -0,0 +1,99 @@
#
# licensed to the apache software foundation (asf) under one
# or more contributor license agreements. see the notice file
# distributed with this work for additional information
# regarding copyright ownership. the asf licenses this file
# to you under the apache license, version 2.0 (the
# "license"); you may not use this file except in compliance
# with the license. you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing,
# software distributed under the license is distributed on an
# "as is" basis, without warranties or conditions of any
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import logging
import sys
from thrift.transport.TTransport import TTransportException
logger = logging.getLogger(__name__)
def legacy_validate_callback(self, cert, hostname):
"""legacy method to validate the peer's SSL certificate, and to check
the commonName of the certificate to ensure it matches the hostname we
used to make this connection. Does not support subjectAltName records
in certificates.
raises TTransportException if the certificate fails validation.
"""
if 'subject' not in cert:
raise TTransportException(
TTransportException.NOT_OPEN,
'No SSL certificate found from %s:%s' % (self.host, self.port))
fields = cert['subject']
for field in fields:
# ensure structure we get back is what we expect
if not isinstance(field, tuple):
continue
cert_pair = field[0]
if len(cert_pair) < 2:
continue
cert_key, cert_value = cert_pair[0:2]
if cert_key != 'commonName':
continue
certhost = cert_value
# this check should be performed by some sort of Access Manager
if certhost == hostname:
# success, cert commonName matches desired hostname
return
else:
raise TTransportException(
TTransportException.UNKNOWN,
'Hostname we connected to "%s" doesn\'t match certificate '
'provided commonName "%s"' % (self.host, certhost))
raise TTransportException(
TTransportException.UNKNOWN,
'Could not validate SSL certificate from host "%s". Cert=%s'
% (hostname, cert))
def _optional_dependencies():
try:
import ipaddress # noqa
logger.debug('ipaddress module is available')
ipaddr = True
except ImportError:
logger.warn('ipaddress module is unavailable')
ipaddr = False
if sys.hexversion < 0x030500F0:
try:
from backports.ssl_match_hostname import match_hostname, __version__ as ver
ver = list(map(int, ver.split('.')))
logger.debug('backports.ssl_match_hostname module is available')
match = match_hostname
if ver[0] * 10 + ver[1] >= 35:
return ipaddr, match
else:
logger.warn('backports.ssl_match_hostname module is too old')
ipaddr = False
except ImportError:
logger.warn('backports.ssl_match_hostname is unavailable')
ipaddr = False
try:
from ssl import match_hostname
logger.debug('ssl.match_hostname is available')
match = match_hostname
except ImportError:
logger.warn('using legacy validation callback')
match = legacy_validate_callback
return ipaddr, match
_match_has_ipaddress, _match_hostname = _optional_dependencies()

View file

@ -0,0 +1,30 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import glob
import os
import sys
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
for libpath in glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')):
if libpath.endswith('-%d.%d' % (sys.version_info[0], sys.version_info[1])):
sys.path.insert(0, libpath)
break

View file

@ -0,0 +1,339 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import inspect
import logging
import os
import platform
import ssl
import sys
import tempfile
import threading
import unittest
import warnings
from contextlib import contextmanager
import _import_local_thrift # noqa
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')
TEST_CIPHERS = 'DES-CBC3-SHA'
class ServerAcceptor(threading.Thread):
def __init__(self, server, expect_failure=False):
super(ServerAcceptor, self).__init__()
self.daemon = True
self._server = server
self._listening = threading.Event()
self._port = None
self._port_bound = threading.Event()
self._client = None
self._client_accepted = threading.Event()
self._expect_failure = expect_failure
frame = inspect.stack(3)[2]
self.name = frame[3]
del frame
def run(self):
self._server.listen()
self._listening.set()
try:
address = self._server.handle.getsockname()
if len(address) > 1:
# AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
# 4-tuples (host, port, ...), but in each case port is in the second slot.
self._port = address[1]
finally:
self._port_bound.set()
try:
self._client = self._server.accept()
except Exception:
logging.exception('error on server side (%s):' % self.name)
if not self._expect_failure:
raise
finally:
self._client_accepted.set()
def await_listening(self):
self._listening.wait()
@property
def port(self):
self._port_bound.wait()
return self._port
@property
def client(self):
self._client_accepted.wait()
return self._client
# Python 2.6 compat
class AssertRaises(object):
def __init__(self, expected):
self._expected = expected
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, traceback):
if not exc_type or not issubclass(exc_type, self._expected):
raise Exception('fail')
return True
class TSSLSocketTest(unittest.TestCase):
def _server_socket(self, **kwargs):
return TSSLServerSocket(port=0, **kwargs)
@contextmanager
def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
acc = ServerAcceptor(server, expect_failure)
try:
acc.start()
acc.await_listening()
host, port = ('localhost', acc.port) if path is None else (None, None)
client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
yield acc, client
finally:
if acc.client:
acc.client.close()
server.close()
def _assert_connection_failure(self, server, path=None, **client_args):
logging.disable(logging.CRITICAL)
try:
with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
# We need to wait for a connection failure, but not too long. 20ms is a tunable
# compromise between test speed and stability
client.setTimeout(20)
with self._assert_raises(TTransportException):
client.open()
self.assertTrue(acc.client is None)
finally:
logging.disable(logging.NOTSET)
def _assert_raises(self, exc):
if sys.hexversion >= 0x020700F0:
return self.assertRaises(exc)
else:
return AssertRaises(exc)
def _assert_connection_success(self, server, path=None, **client_args):
with self._connectable_client(server, path=path, **client_args) as (acc, client):
client.open()
try:
self.assertTrue(acc.client is not None)
finally:
client.close()
# deprecated feature
def test_deprecation(self):
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
self.assertEqual(len(w), 1)
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
# Deprecated signature
# def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
self.assertEqual(len(w), 7)
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
# Deprecated signature
# def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
self.assertEqual(len(w), 3)
# deprecated feature
def test_set_cert_reqs_by_validate(self):
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)
c1 = TSSLSocket('localhost', 0, validate=False)
self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)
self.assertEqual(len(w), 2)
# deprecated feature
def test_set_validate_by_cert_reqs(self):
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
self.assertFalse(c1.validate)
c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
self.assertTrue(c2.validate)
c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
self.assertTrue(c3.validate)
self.assertEqual(len(w), 3)
def test_unix_domain_socket(self):
if platform.system() == 'Windows':
print('skipping test_unix_domain_socket')
return
fd, path = tempfile.mkstemp()
os.close(fd)
try:
server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
finally:
os.unlink(path)
def test_server_cert(self):
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
# server cert not in ca_certs
self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)
def test_set_server_cert(self):
server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
with self._assert_raises(Exception):
server.certfile = 'foo'
with self._assert_raises(Exception):
server.certfile = None
server.certfile = SERVER_CERT
self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
def test_client_cert(self):
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CA)
self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)
server = self._server_socket(
cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CA)
self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
server = self._server_socket(
cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
certfile=SERVER_CERT, ca_certs=CLIENT_CA)
self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
def test_ciphers(self):
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)
if not TSSLSocket._has_ciphers:
# unittest.skip is not available for Python 2.6
print('skipping test_ciphers')
return
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')
def test_ssl2_and_ssl3_disabled(self):
if not hasattr(ssl, 'PROTOCOL_SSLv3'):
print('PROTOCOL_SSLv3 is not available')
else:
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
self._assert_connection_failure(server, ca_certs=SERVER_CERT)
if not hasattr(ssl, 'PROTOCOL_SSLv2'):
print('PROTOCOL_SSLv2 is not available')
else:
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
self._assert_connection_failure(server, ca_certs=SERVER_CERT)
def test_newer_tls(self):
if not TSSLSocket._has_ssl_context:
# unittest.skip is not available for Python 2.6
print('skipping test_newer_tls')
return
if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
print('PROTOCOL_TLSv1_2 is not available')
else:
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
print('PROTOCOL_TLSv1_1 is not available')
else:
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
else:
server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
def test_ssl_context(self):
if not TSSLSocket._has_ssl_context:
# unittest.skip is not available for Python 2.6
print('skipping test_ssl_context')
return
server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
server_context.load_verify_locations(CLIENT_CA)
server_context.verify_mode = ssl.CERT_REQUIRED
server = self._server_socket(ssl_context=server_context)
client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
client_context.load_verify_locations(SERVER_CERT)
client_context.verify_mode = ssl.CERT_REQUIRED
self._assert_connection_success(server, ssl_context=client_context)
if __name__ == '__main__':
logging.basicConfig(level=logging.WARN)
from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket
from thrift.transport.TTransport import TTransportException
unittest.main()

View file

@ -0,0 +1,50 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys
import unittest
import _import_local_thrift # noqa
from thrift.protocol.TJSONProtocol import TJSONProtocol
from thrift.transport import TTransport
#
# In order to run the test under Windows. We need to create symbolic link
# name 'thrift' to '../src' folder by using:
#
# mklink /D thrift ..\src
#
class TestJSONString(unittest.TestCase):
def test_escaped_unicode_string(self):
unicode_json = b'"hello \\u0e01\\u0e02\\u0e03\\ud835\\udcab\\udb40\\udc70 unicode"'
unicode_text = u'hello \u0e01\u0e02\u0e03\U0001D4AB\U000E0070 unicode'
buf = TTransport.TMemoryBuffer(unicode_json)
transport = TTransport.TBufferedTransportFactory().getTransport(buf)
protocol = TJSONProtocol(transport)
if sys.version_info[0] == 2:
unicode_text = unicode_text.encode('utf8')
self.assertEqual(protocol.readString(), unicode_text)
if __name__ == '__main__':
unittest.main()