Moving from govendor to dep, updated dependencies (#48)

* Moving from govendor to dep.

* Making the pull request template more friendly.

* Fixing akward space in PR template.

* goimports run on whole project using ` goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./gen-go/*")`

source of command: https://gist.github.com/bgentry/fd1ffef7dbde01857f66
This commit is contained in:
Renan DelValle 2018-01-07 13:13:47 -08:00 committed by GitHub
parent 9631aa3aab
commit 8d445c1c77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2186 changed files with 400410 additions and 352 deletions

194
vendor/git.apache.org/thrift.git/lib/d/Makefile.am generated vendored Normal file
View file

@ -0,0 +1,194 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
AUTOMAKE_OPTIONS = serial-tests
SUBDIRS = .
if WITH_TESTS
SUBDIRS += test
endif
#
# Enumeration of all the public and private modules.
#
# We unconditionally install all of them, even if libevent or OpenSSL are
# not available, but build the respective libraries only if the Deimos headers
# could be found.
#
d_thriftmodules = $(addprefix thrift/, base)
d_thriftdir = $(D_IMPORT_PREFIX)/thrift
d_thrift_DATA = $(addprefix src/, $(addsuffix .d, $(d_thriftmodules)))
d_asyncmodules = $(addprefix thrift/async/, base libevent socket ssl)
d_asyncdir = $(d_thriftdir)/async
d_async_DATA = $(addprefix src/, $(addsuffix .d, $(d_asyncmodules)))
d_codegenmodules = $(addprefix thrift/codegen/, async_client \
async_client_pool base client client_pool processor)
#d_codegenmodules = $(addprefix thrift/codegen/, async_client \
# async_client_pool base client client_pool idlgen processor)
d_codegendir = $(d_thriftdir)/codegen
d_codegen_DATA = $(addprefix src/, $(addsuffix .d, $(d_codegenmodules)))
d_protocolmodules = $(addprefix thrift/protocol/, base binary compact json \
processor)
d_protocoldir = $(d_thriftdir)/protocol
d_protocol_DATA = $(addprefix src/, $(addsuffix .d, $(d_protocolmodules)))
d_servermodules = $(addprefix thrift/server/, base simple nonblocking \
taskpool threaded)
d_serverdir = $(d_thriftdir)/server
d_server_DATA = $(addprefix src/, $(addsuffix .d, $(d_servermodules)))
d_servertransportmodules = $(addprefix thrift/server/transport/, base socket ssl)
d_servertransportdir = $(d_thriftdir)/server/transport
d_servertransport_DATA = $(addprefix src/, $(addsuffix .d, \
$(d_servertransportmodules)))
d_transportmodules = $(addprefix thrift/transport/, base buffered file \
framed http memory piped range socket ssl zlib)
d_transportdir = $(d_thriftdir)/transport
d_transport_DATA = $(addprefix src/, $(addsuffix .d, $(d_transportmodules)))
d_utilmodules = $(addprefix thrift/util/, awaitable cancellation future \
hashset)
d_utildir = $(d_thriftdir)/util
d_util_DATA = $(addprefix src/, $(addsuffix .d, $(d_utilmodules)))
d_internalmodules = $(addprefix thrift/internal/, algorithm codegen ctfe \
endian resource_pool socket ssl ssl_bio traits)
d_internaldir = $(d_thriftdir)/internal
d_internal_DATA = $(addprefix src/, $(addsuffix .d, $(d_internalmodules)))
d_testmodules = $(addprefix thrift/internal/test/, protocol server)
d_testdir = $(d_internaldir)/test
d_test_DATA = $(addprefix src/, $(addsuffix .d, $(d_testmodules)))
d_publicmodules = $(d_thriftmodules) $(d_asyncmodules) \
$(d_codegenmodules) $(d_protocolmodules) $(d_servermodules) \
$(d_servertransportmodules) $(d_transportmodules) $(d_utilmodules)
d_publicsources = $(addprefix src/, $(addsuffix .d, $(d_publicmodules)))
d_modules = $(d_publicmodules) $(d_internalmodules) $(d_testmodules)
# List modules with external dependencies and remove them from the main list
d_libevent_dependent_modules = thrift/async/libevent thrift/server/nonblocking
d_openssl_dependent_modules = thrift/async/ssl thrift/internal/ssl \
thrift/internal/ssl_bio thrift/transport/ssl thrift/server/transport/ssl
d_main_modules = $(filter-out $(d_libevent_dependent_modules) \
$(d_openssl_dependent_modules),$(d_modules))
d_lib_flags = -w -wi -Isrc -lib
all_targets =
#
# libevent-dependent modules.
#
if HAVE_DEIMOS_EVENT2
$(D_EVENT_LIB_NAME): $(addprefix src/, $(addsuffix .d, $(d_libevent_dependent_modules)))
$(DMD) -of$(D_EVENT_LIB_NAME) $(d_lib_flags) $^
all_targets += $(D_EVENT_LIB_NAME)
endif
#
# OpenSSL-dependent modules.
#
if HAVE_DEIMOS_OPENSSL
$(D_SSL_LIB_NAME): $(addprefix src/, $(addsuffix .d, $(d_openssl_dependent_modules)))
$(DMD) -of$(D_SSL_LIB_NAME) $(d_lib_flags) $^
all_targets += $(D_SSL_LIB_NAME)
endif
#
# Main library target.
#
$(D_LIB_NAME): $(addprefix src/, $(addsuffix .d, $(d_main_modules)))
$(DMD) -of$(D_LIB_NAME) $(d_lib_flags) $^
all_targets += $(D_LIB_NAME)
#
# Documentation target (requires Dil).
#
docs: $(d_publicsources) src/thrift/index.d
dil ddoc docs -hl --kandil $^
#
# Hook custom library targets into the automake all/install targets.
#
all-local: $(all_targets)
install-exec-local:
$(INSTALL_PROGRAM) $(all_targets) $(DESTDIR)$(libdir)
clean-local:
$(RM) -rf docs $(D_LIB_NAME) $(D_EVENT_LIB_NAME) $(D_SSL_LIB_NAME) unittest
#
# Unit tests (built both in debug and release mode).
#
d_test_flags = -unittest -w -wi -I$(top_srcdir)/lib/d/src
# There just must be some way to reassign a variable without warnings in
# Automake...
d_test_modules__ = $(d_modules)
if WITH_D_EVENT_TESTS
d_test_flags += $(DMD_LIBEVENT_FLAGS)
d_test_modules_ = $(d_test_modules__)
else
d_test_modules_ = $(filter-out $(d_libevent_dependent_modules), $(d_test_modules__))
endif
if WITH_D_SSL_TESTS
d_test_flags += $(DMD_OPENSSL_FLAGS)
d_test_modules = $(d_test_modules_)
else
d_test_modules = $(filter-out $(d_openssl_dependent_modules), $(d_test_modules_))
endif
unittest/emptymain.d: unittest/.directory
@echo 'void main(){}' >$@
unittest/.directory:
mkdir -p unittest || exists unittest
touch $@
unittest/debug/%: src/%.d $(all_targets) unittest/emptymain.d
$(DMD) -gc -of$(subst /,$(DMD_OF_DIRSEP),$@) $(d_test_flags) $^
unittest/release/%: src/%.d $(all_targets) unittest/emptymain.d
$(DMD) -O -release -of$(subst /,$(DMD_OF_DIRSEP),$@) $(d_test_flags) $^
TESTS = $(addprefix unittest/debug/, $(d_test_modules)) \
$(addprefix unittest/release/, $(d_test_modules))
precross: all-local
$(MAKE) -C test precross
EXTRA_DIST = \
src \
test \
README.md

58
vendor/git.apache.org/thrift.git/lib/d/README.md generated vendored Normal file
View file

@ -0,0 +1,58 @@
Thrift D Software Library
=========================
License
-------
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
Testing
-------
D support in Thrift is covered by two sets of tests: first,
the unit test blocks contained in the D source files, and
second, the more extensive testing applications in the test/
subdirectory, which also make use of the Thrift compiler.
Both are built when running "make check", but only the
unit tests are immediately run, however the separate test
cases typically run longer or require manual intervention.
It might also be prudent to run the independent tests,
which typically consist of a server and a client part,
against the other language implementations.
To build the unit tests on Windows, the easiest way might
be to manually create a file containing an empty main() and
invoke the compiler by running the following in the src/
directory (PowerShell syntax):
dmd -ofunittest -unittest -w $(dir -r -filter '*.d' -name)
If you want to run the test clients/servers in OpenSSL
mode, you have to provide »server-private-key.pem« and
»server-certificate.pem« files in the directory the server
executable resides in, and a »trusted-ca-certificate.pem«
file for the client. The easiest way is to generate a new
self signed certificate using the provided config file
(openssl.test.cnf):
openssl req -new -x509 -nodes -config openssl.test.cnf \
-out server-certificate.pem
cat server-certificate.pem > trusted-ca-certificate.pem
This steps are also performed automatically by the
Autotools build system if the files are not present.

View file

@ -0,0 +1 @@
Please follow [General Coding Standards](/doc/coding_standards.md)

View file

@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Defines the interface used for client-side handling of asynchronous
* I/O operations, based on coroutines.
*
* The main piece of the »client side« (e.g. for TAsyncClient users) of the
* API is TFuture, which represents an asynchronously executed operation,
* which can have a return value, throw exceptions, and which can be waited
* upon.
*
* On the »implementation side«, the idea is that by using a TAsyncTransport
* instead of a normal TTransport and executing the work through a
* TAsyncManager, the same code as for synchronous I/O can be used for
* asynchronous operation as well, for example:
*
* ---
* auto socket = new TAsyncSocket(someTAsyncSocketManager(), host, port);
* // …
* socket.asyncManager.execute(socket, {
* SomeThriftStruct s;
*
* // Waiting for socket I/O will not block an entire thread but cause
* // the async manager to execute another task in the meantime, because
* // we are using TAsyncSocket instead of TSocket.
* s.read(socket);
*
* // Do something with s, e.g. set a TPromise result to it.
* writeln(s);
* });
* ---
*/
module thrift.async.base;
import core.time : Duration, dur;
import std.socket/+ : Socket+/; // DMD @@BUG314@@
import thrift.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* Manages one or more asynchronous transport resources (e.g. sockets in the
* case of TAsyncSocketManager) and allows work items to be submitted for them.
*
* Implementations will typically run one or more background threads for
* executing the work, which is one of the reasons for a TAsyncManager to be
* used. Each work item is run in its own fiber and is expected to yield() away
* while waiting for time-consuming operations.
*
* The second important purpose of TAsyncManager is to serialize access to
* the transport resources without taking care of that, e.g. issuing multiple
* RPC calls over the same connection in rapid succession would likely lead to
* more than one request being written at the same time, causing only garbage
* to arrive at the remote end.
*
* All methods are thread-safe.
*/
interface TAsyncManager {
/**
* Submits a work item to be executed asynchronously.
*
* Access to asnyc transports is serialized if two work items associated
* with the same transport are submitted, the second delegate will not be
* invoked until the first has returned, even it the latter context-switches
* away (because it is waiting for I/O) and the async manager is idle
* otherwise.
*
* Optionally, a TCancellation instance can be specified. If present,
* triggering it will be considered a request to cancel the work item, if it
* is still waiting for the associated transport to become available.
* Delegates which are already being processed (i.e. waiting for I/O) are not
* affected because this would bring the connection into an undefined state
* (as probably half-written request or a half-read response would be left
* behind).
*
* Params:
* transport = The TAsyncTransport the work delegate will operate on. Must
* be associated with this TAsyncManager instance.
* work = The operations to execute on the given transport. Must never
* throw, errors should be handled in another way. nothrow semantics are
* difficult to enforce in combination with fibres though, so currently
* exceptions are just swallowed by TAsyncManager implementations.
* cancellation = If set, can be used to request cancellatinon of this work
* item if it is still waiting to be executed.
*
* Note: The work item will likely be executed in a different thread, so make
* sure the code it relies on is thread-safe. An exception are the async
* transports themselves, to which access is serialized as noted above.
*/
void execute(TAsyncTransport transport, void delegate() work,
TCancellation cancellation = null
) in {
assert(transport.asyncManager is this,
"The given transport must be associated with this TAsyncManager.");
}
/**
* Submits a delegate to be executed after a certain amount of time has
* passed.
*
* The actual amount of time elapsed can be higher if the async manager
* instance is busy and thus should not be relied on. The
*
* Params:
* duration = The amount of time to wait before starting to execute the
* work delegate.
* work = The code to execute after the specified amount of time has passed.
*
* Example:
* ---
* // A very basic example usually, the actuall work item would enqueue
* // some async transport operation.
* auto asyncMangager = someAsyncManager();
*
* TFuture!int calculate() {
* // Create a promise and asynchronously set its value after three
* // seconds have passed.
* auto promise = new TPromise!int;
* asyncManager.delay(dur!"seconds"(3), {
* promise.succeed(42);
* });
*
* // Immediately return it to the caller.
* return promise;
* }
*
* // This will wait until the result is available and then print it.
* writeln(calculate().waitGet());
* ---
*/
void delay(Duration duration, void delegate() work);
/**
* Shuts down all background threads or other facilities that might have
* been started in order to execute work items. This function is typically
* called during program shutdown.
*
* If there are still tasks to be executed when the timeout expires, any
* currently executed work items will never receive any notifications
* for async transports managed by this instance, queued work items will
* be silently dropped, and implementations are allowed to leak resources.
*
* Params:
* waitFinishTimeout = If positive, waits for all work items to be
* finished for the specified amount of time, if negative, waits for
* completion without ever timing out, if zero, immediately shuts down
* the background facilities.
*/
bool stop(Duration waitFinishTimeout = dur!"hnsecs"(-1));
}
/**
* A TTransport which uses a TAsyncManager to schedule non-blocking operations.
*
* The actual type of device is not specified; typically, implementations will
* depend on an interface derived from TAsyncManager to be notified of changes
* in the transport state.
*
* The peeking, reading, writing and flushing methods must always be called
* from within the associated async manager.
*/
interface TAsyncTransport : TTransport {
/**
* The TAsyncManager associated with this transport.
*/
TAsyncManager asyncManager() @property;
}
/**
* A TAsyncManager providing notificiations for socket events.
*/
interface TAsyncSocketManager : TAsyncManager {
/**
* Adds a listener that is triggered once when an event of the specified type
* occurs, and removed afterwards.
*
* Params:
* socket = The socket to listen for events at.
* eventType = The type of the event to listen for.
* timeout = The period of time after which the listener will be called
* with TAsyncEventReason.TIMED_OUT if no event happened.
* listener = The delegate to call when an event happened.
*/
void addOneshotListener(Socket socket, TAsyncEventType eventType,
Duration timeout, TSocketEventListener listener);
/// Ditto
void addOneshotListener(Socket socket, TAsyncEventType eventType,
TSocketEventListener listener);
}
/**
* Types of events that can happen for an asynchronous transport.
*/
enum TAsyncEventType {
READ, /// New data became available to read.
WRITE /// The transport became ready to be written to.
}
/**
* The type of the delegates used to register socket event handlers.
*/
alias void delegate(TAsyncEventReason callReason) TSocketEventListener;
/**
* The reason a listener was called.
*/
enum TAsyncEventReason : byte {
NORMAL, /// The event listened for was triggered normally.
TIMED_OUT /// A timeout for the event was set, and it expired.
}

View file

@ -0,0 +1,461 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.async.libevent;
import core.atomic;
import core.time : Duration, dur;
import core.exception : onOutOfMemoryError;
import core.memory : GC;
import core.thread : Fiber, Thread;
import core.sync.condition;
import core.sync.mutex;
import core.stdc.stdlib : free, malloc;
import deimos.event2.event;
import std.array : empty, front, popFront;
import std.conv : text, to;
import std.exception : enforce;
import std.socket : Socket, socketPair;
import thrift.base;
import thrift.async.base;
import thrift.internal.socket;
import thrift.internal.traits;
import thrift.util.cancellation;
// To avoid DMD @@BUG6395@@.
import thrift.internal.algorithm;
/**
* A TAsyncManager implementation based on libevent.
*
* The libevent loop for handling non-blocking sockets is run in a background
* thread, which is lazily spawned. The thread is not daemonized to avoid
* crashes on program shutdown, it is only stopped when the manager instance
* is destroyed. So, to ensure a clean program teardown, either make sure this
* instance gets destroyed (e.g. by using scope), or manually call stop() at
* the end.
*/
class TLibeventAsyncManager : TAsyncSocketManager {
this() {
eventBase_ = event_base_new();
// Set up the socket pair for transferring control messages to the event
// loop.
auto pair = socketPair();
controlSendSocket_ = pair[0];
controlReceiveSocket_ = pair[1];
controlReceiveSocket_.blocking = false;
// Register an event for receiving control messages.
controlReceiveEvent_ = event_new(eventBase_, controlReceiveSocket_.handle,
EV_READ | EV_PERSIST | EV_ET, assumeNothrow(&controlMsgReceiveCallback),
cast(void*)this);
event_add(controlReceiveEvent_, null);
queuedCountMutex_ = new Mutex;
zeroQueuedCondition_ = new Condition(queuedCountMutex_);
}
~this() {
// stop() should be safe to call, because either we don't have a worker
// thread running and it is a no-op anyway, or it is guaranteed to be
// still running (blocked in event_base_loop), and thus guaranteed not to
// be garbage collected yet.
stop(dur!"hnsecs"(0));
event_free(controlReceiveEvent_);
event_base_free(eventBase_);
eventBase_ = null;
}
override void execute(TAsyncTransport transport, Work work,
TCancellation cancellation = null
) {
if (cancellation && cancellation.triggered) return;
// Keep track that there is a new work item to be processed.
incrementQueuedCount();
ensureWorkerThreadRunning();
// We should be able to send the control message as a whole we currently
// assume to be able to receive it at once as well. If this proves to be
// unstable (e.g. send could possibly return early if the receiving buffer
// is full and the blocking call gets interrupted by a signal), it could
// be changed to a more sophisticated scheme.
// Make sure the delegate context doesn't get GCd while the work item is
// on the wire.
GC.addRoot(work.ptr);
// Send work message.
sendControlMsg(ControlMsg(MsgType.WORK, work, transport));
if (cancellation) {
cancellation.triggering.addCallback({
sendControlMsg(ControlMsg(MsgType.CANCEL, work, transport));
});
}
}
override void delay(Duration duration, void delegate() work) {
incrementQueuedCount();
ensureWorkerThreadRunning();
const tv = toTimeval(duration);
// DMD @@BUG@@: Cannot deduce T to void delegate() here.
registerOneshotEvent!(void delegate())(
-1, 0, assumeNothrow(&delayCallback), &tv,
{
work();
decrementQueuedCount();
}
);
}
override bool stop(Duration waitFinishTimeout = dur!"hnsecs"(-1)) {
bool cleanExit = true;
synchronized (this) {
if (workerThread_) {
synchronized (queuedCountMutex_) {
if (waitFinishTimeout > dur!"hnsecs"(0)) {
if (queuedCount_ > 0) {
zeroQueuedCondition_.wait(waitFinishTimeout);
}
} else if (waitFinishTimeout < dur!"hnsecs"(0)) {
while (queuedCount_ > 0) zeroQueuedCondition_.wait();
} else {
// waitFinishTimeout is zero, immediately exit in all cases.
}
cleanExit = (queuedCount_ == 0);
}
event_base_loopbreak(eventBase_);
sendControlMsg(ControlMsg(MsgType.SHUTDOWN));
workerThread_.join();
workQueues_ = null;
// We have nuked all currently enqueued items, so set the count to
// zero. This is safe to do without locking, since the worker thread
// is down.
queuedCount_ = 0;
atomicStore(*(cast(shared)&workerThread_), cast(shared(Thread))null);
}
}
return cleanExit;
}
override void addOneshotListener(Socket socket, TAsyncEventType eventType,
TSocketEventListener listener
) {
addOneshotListenerImpl(socket, eventType, null, listener);
}
override void addOneshotListener(Socket socket, TAsyncEventType eventType,
Duration timeout, TSocketEventListener listener
) {
if (timeout <= dur!"hnsecs"(0)) {
addOneshotListenerImpl(socket, eventType, null, listener);
} else {
// This is not really documented well, but libevent does not require to
// keep the timeval around after the event was added.
auto tv = toTimeval(timeout);
addOneshotListenerImpl(socket, eventType, &tv, listener);
}
}
private:
alias void delegate() Work;
void addOneshotListenerImpl(Socket socket, TAsyncEventType eventType,
const(timeval)* timeout, TSocketEventListener listener
) {
registerOneshotEvent(socket.handle, libeventEventType(eventType),
assumeNothrow(&socketCallback), timeout, listener);
}
void registerOneshotEvent(T)(evutil_socket_t fd, short type,
event_callback_fn callback, const(timeval)* timeout, T payload
) {
// Create a copy of the payload on the C heap.
auto payloadMem = malloc(payload.sizeof);
if (!payloadMem) onOutOfMemoryError();
(cast(T*)payloadMem)[0 .. 1] = payload;
GC.addRange(payloadMem, payload.sizeof);
auto result = event_base_once(eventBase_, fd, type, callback,
payloadMem, timeout);
// Assuming that we didn't get our arguments wrong above, the only other
// situation in which event_base_once can fail is when it can't allocate
// memory.
if (result != 0) onOutOfMemoryError();
}
enum MsgType : ubyte {
SHUTDOWN,
WORK,
CANCEL
}
struct ControlMsg {
MsgType type;
Work work;
TAsyncTransport transport;
}
/**
* Starts the worker thread if it is not already running.
*/
void ensureWorkerThreadRunning() {
// Technically, only half barriers would be required here, but adding the
// argument seems to trigger a DMD template argument deduction @@BUG@@.
if (!atomicLoad(*(cast(shared)&workerThread_))) {
synchronized (this) {
if (!workerThread_) {
auto thread = new Thread({ event_base_loop(eventBase_, 0); });
thread.start();
atomicStore(*(cast(shared)&workerThread_), cast(shared)thread);
}
}
}
}
/**
* Sends a control message to the worker thread.
*/
void sendControlMsg(const(ControlMsg) msg) {
auto result = controlSendSocket_.send((&msg)[0 .. 1]);
enum size = msg.sizeof;
enforce(result == size, new TException(text(
"Sending control message of type ", msg.type, " failed (", result,
" bytes instead of ", size, " transmitted).")));
}
/**
* Receives messages from the control message socket and acts on them. Called
* from the worker thread.
*/
void receiveControlMsg() {
// Read as many new work items off the socket as possible (at least one
// should be available, as we got notified by libevent).
ControlMsg msg;
ptrdiff_t bytesRead;
while (true) {
bytesRead = controlReceiveSocket_.receive(cast(ubyte[])((&msg)[0 .. 1]));
if (bytesRead < 0) {
auto errno = getSocketErrno();
if (errno != WOULD_BLOCK_ERRNO) {
logError("Reading control message, some work item will possibly " ~
"never be executed: %s", socketErrnoString(errno));
}
}
if (bytesRead != msg.sizeof) break;
// Everything went fine, we received a new control message.
final switch (msg.type) {
case MsgType.SHUTDOWN:
// The message was just intended to wake us up for shutdown.
break;
case MsgType.CANCEL:
// When processing a cancellation, we must not touch the first item,
// since it is already being processed.
auto queue = workQueues_[msg.transport];
if (queue.length > 0) {
workQueues_[msg.transport] = [queue[0]] ~
removeEqual(queue[1 .. $], msg.work);
}
break;
case MsgType.WORK:
// Now that the work item is back in the D world, we don't need the
// extra GC root for the context pointer anymore (see execute()).
GC.removeRoot(msg.work.ptr);
// Add the work item to the queue and execute it.
auto queue = msg.transport in workQueues_;
if (queue is null || (*queue).empty) {
// If the queue is empty, add the new work item to the queue as well,
// but immediately start executing it.
workQueues_[msg.transport] = [msg.work];
executeWork(msg.transport, msg.work);
} else {
(*queue) ~= msg.work;
}
break;
}
}
// If the last read was successful, but didn't read enough bytes, we got
// a problem.
if (bytesRead > 0) {
logError("Unexpected partial control message read (%s byte(s) " ~
"instead of %s), some work item will possibly never be executed.",
bytesRead, msg.sizeof);
}
}
/**
* Executes the given work item and all others enqueued for the same
* transport in a new fiber. Called from the worker thread.
*/
void executeWork(TAsyncTransport transport, Work work) {
(new Fiber({
auto item = work;
while (true) {
try {
// Execute the actual work. It will possibly add listeners to the
// event loop and yield away if it has to wait for blocking
// operations. It is quite possible that another fiber will modify
// the work queue for the current transport.
item();
} catch (Exception e) {
// This should never happen, just to be sure the worker thread
// doesn't stop working in mysterious ways because of an unhandled
// exception.
logError("Exception thrown by work item: %s", e);
}
// Remove the item from the work queue.
// Note: Due to the value semantics of array slices, we have to
// re-lookup this on every iteration. This could be solved, but I'd
// rather replace this directly with a queue type once one becomes
// available in Phobos.
auto queue = workQueues_[transport];
assert(queue.front == item);
queue.popFront();
workQueues_[transport] = queue;
// Now that the work item is done, no longer count it as queued.
decrementQueuedCount();
if (queue.empty) break;
// If the queue is not empty, execute the next waiting item.
item = queue.front;
}
})).call();
}
/**
* Increments the amount of queued items.
*/
void incrementQueuedCount() {
synchronized (queuedCountMutex_) {
++queuedCount_;
}
}
/**
* Decrements the amount of queued items.
*/
void decrementQueuedCount() {
synchronized (queuedCountMutex_) {
assert(queuedCount_ > 0);
--queuedCount_;
if (queuedCount_ == 0) {
zeroQueuedCondition_.notifyAll();
}
}
}
static extern(C) void controlMsgReceiveCallback(evutil_socket_t, short,
void *managerThis
) {
(cast(TLibeventAsyncManager)managerThis).receiveControlMsg();
}
static extern(C) void socketCallback(evutil_socket_t, short flags,
void *arg
) {
auto reason = (flags & EV_TIMEOUT) ? TAsyncEventReason.TIMED_OUT :
TAsyncEventReason.NORMAL;
(*(cast(TSocketEventListener*)arg))(reason);
GC.removeRange(arg);
destroy(arg);
free(arg);
}
static extern(C) void delayCallback(evutil_socket_t, short flags,
void *arg
) {
assert(flags & EV_TIMEOUT);
(*(cast(void delegate()*)arg))();
GC.removeRange(arg);
destroy(arg);
free(arg);
}
Thread workerThread_;
event_base* eventBase_;
/// The socket used for receiving new work items in the event loop. Paired
/// with controlSendSocket_. Invalid (i.e. TAsyncWorkItem.init) items are
/// ignored and can be used to wake up the worker thread.
Socket controlReceiveSocket_;
event* controlReceiveEvent_;
/// The socket used to send new work items to the event loop. It is
/// expected that work items can always be read at once from it, i.e. that
/// there will never be short reads.
Socket controlSendSocket_;
/// Queued up work delegates for async transports. This also includes
/// currently active ones, they are removed from the queue on completion,
/// which is relied on by the control message receive fiber (the main one)
/// to decide whether to immediately start executing items or not.
// TODO: This should really be of some queue type, not an array slice, but
// std.container doesn't have anything.
Work[][TAsyncTransport] workQueues_;
/// The total number of work items not yet finished (queued and currently
/// executed) and delays not yet executed.
uint queuedCount_;
/// Protects queuedCount_.
Mutex queuedCountMutex_;
/// Triggered when queuedCount_ reaches zero, protected by queuedCountMutex_.
Condition zeroQueuedCondition_;
}
private {
timeval toTimeval(const(Duration) dur) {
timeval tv;
dur.split!("seconds", "usecs")(tv.tv_sec, tv.tv_usec);
return tv;
}
/**
* Returns the libevent flags combination to represent a given TAsyncEventType.
*/
short libeventEventType(TAsyncEventType type) {
final switch (type) {
case TAsyncEventType.READ:
return EV_READ | EV_ET;
case TAsyncEventType.WRITE:
return EV_WRITE | EV_ET;
}
}
}

View file

@ -0,0 +1,357 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.async.socket;
import core.thread : Fiber;
import core.time : dur, Duration;
import std.array : empty;
import std.conv : to;
import std.exception : enforce;
import std.socket;
import thrift.base;
import thrift.async.base;
import thrift.transport.base;
import thrift.transport.socket : TSocketBase;
import thrift.internal.endian;
import thrift.internal.socket;
version (Windows) {
import std.c.windows.winsock : connect;
} else version (Posix) {
import core.sys.posix.sys.socket : connect;
} else static assert(0, "Don't know connect on this platform.");
/**
* Non-blocking socket implementation of the TTransport interface.
*
* Whenever a socket operation would block, TAsyncSocket registers a callback
* with the specified TAsyncSocketManager and yields.
*
* As for thrift.transport.socket, due to the limitations of std.socket,
* currently only TCP/IP sockets are supported (i.e. Unix domain sockets are
* not).
*/
class TAsyncSocket : TSocketBase, TAsyncTransport {
/**
* Constructor that takes an already created, connected (!) socket.
*
* Params:
* asyncManager = The TAsyncSocketManager to use for non-blocking I/O.
* socket = Already created, connected socket object. Will be switched to
* non-blocking mode if it isn't already.
*/
this(TAsyncSocketManager asyncManager, Socket socket) {
asyncManager_ = asyncManager;
socket.blocking = false;
super(socket);
}
/**
* Creates a new unconnected socket that will connect to the given host
* on the given port.
*
* Params:
* asyncManager = The TAsyncSocketManager to use for non-blocking I/O.
* host = Remote host.
* port = Remote port.
*/
this(TAsyncSocketManager asyncManager, string host, ushort port) {
asyncManager_ = asyncManager;
super(host, port);
}
override TAsyncManager asyncManager() @property {
return asyncManager_;
}
/**
* Asynchronously connects the socket.
*
* Completes without blocking and defers further operations on the socket
* until the connection is established. If connecting fails, this is
* currently not indicated in any way other than every call to read/write
* failing.
*/
override void open() {
if (isOpen) return;
enforce(!host_.empty, new TTransportException(
"Cannot open null host.", TTransportException.Type.NOT_OPEN));
enforce(port_ != 0, new TTransportException(
"Cannot open with null port.", TTransportException.Type.NOT_OPEN));
// Cannot use std.socket.Socket.connect here because it hides away
// EINPROGRESS/WSAWOULDBLOCK.
Address addr;
try {
// Currently, we just go with the first address returned, could be made
// more intelligent though IPv6?
addr = getAddress(host_, port_)[0];
} catch (Exception e) {
throw new TTransportException(`Unable to resolve host "` ~ host_ ~ `".`,
TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
}
socket_ = new TcpSocket(addr.addressFamily);
socket_.blocking = false;
setSocketOpts();
auto errorCode = connect(socket_.handle, addr.name(), addr.nameLen());
if (errorCode == 0) {
// If the connection could be established immediately, just return. I
// don't know if this ever happens.
return;
}
auto errno = getSocketErrno();
if (errno != CONNECT_INPROGRESS_ERRNO) {
throw new TTransportException(`Could not establish connection to "` ~
host_ ~ `": ` ~ socketErrnoString(errno),
TTransportException.Type.NOT_OPEN);
}
// This is the expected case: connect() signalled that the connection
// is being established in the background. Queue up a work item with the
// async manager which just defers any other operations on this
// TAsyncSocket instance until the socket is ready.
asyncManager_.execute(this,
{
auto fiber = Fiber.getThis();
TAsyncEventReason reason = void;
asyncManager_.addOneshotListener(socket_, TAsyncEventType.WRITE,
connectTimeout,
scopedDelegate((TAsyncEventReason r){ reason = r; fiber.call(); })
);
Fiber.yield();
if (reason == TAsyncEventReason.TIMED_OUT) {
// Close the connection, so that subsequent work items fail immediately.
closeImmediately();
return;
}
int errorCode = void;
socket_.getOption(SocketOptionLevel.SOCKET, cast(SocketOption)SO_ERROR,
errorCode);
if (errorCode) {
logInfo("Could not connect TAsyncSocket: %s",
socketErrnoString(errorCode));
// Close the connection, so that subsequent work items fail immediately.
closeImmediately();
return;
}
}
);
}
/**
* Closes the socket.
*
* Will block until all currently active operations are finished before the
* socket is closed.
*/
override void close() {
if (!isOpen) return;
import core.sync.condition;
import core.sync.mutex;
auto doneMutex = new Mutex;
auto doneCond = new Condition(doneMutex);
synchronized (doneMutex) {
asyncManager_.execute(this,
scopedDelegate(
{
closeImmediately();
synchronized (doneMutex) doneCond.notifyAll();
}
)
);
doneCond.wait();
}
}
override bool peek() {
if (!isOpen) return false;
ubyte buf;
auto r = socket_.receive((&buf)[0..1], SocketFlags.PEEK);
if (r == Socket.ERROR) {
auto lastErrno = getSocketErrno();
static if (connresetOnPeerShutdown) {
if (lastErrno == ECONNRESET) {
closeImmediately();
return false;
}
}
throw new TTransportException("Peeking into socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
return (r > 0);
}
override size_t read(ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));
typeof(getSocketErrno()) lastErrno;
auto r = yieldOnBlock(socket_.receive(cast(void[])buf),
TAsyncEventType.READ);
// If recv went fine, immediately return.
if (r >= 0) return r;
// Something went wrong, find out how to handle it.
lastErrno = getSocketErrno();
static if (connresetOnPeerShutdown) {
// See top comment.
if (lastErrno == ECONNRESET) {
return 0;
}
}
throw new TTransportException("Receiving from socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
override void write(in ubyte[] buf) {
size_t sent;
while (sent < buf.length) {
sent += writeSome(buf[sent .. $]);
}
assert(sent == buf.length);
}
override size_t writeSome(in ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot write if socket is not open.", TTransportException.Type.NOT_OPEN));
auto r = yieldOnBlock(socket_.send(buf), TAsyncEventType.WRITE);
// Everything went well, just return the number of bytes written.
if (r > 0) return r;
// Handle error conditions.
if (r < 0) {
auto lastErrno = getSocketErrno();
auto type = TTransportException.Type.UNKNOWN;
if (isSocketCloseErrno(lastErrno)) {
type = TTransportException.Type.NOT_OPEN;
closeImmediately();
}
throw new TTransportException("Sending to socket failed: " ~
socketErrnoString(lastErrno), type);
}
// send() should never return 0.
throw new TTransportException("Sending to socket failed (0 bytes written).",
TTransportException.Type.UNKNOWN);
}
/// The amount of time in which a conncetion must be established before the
/// open() call times out.
Duration connectTimeout = dur!"seconds"(5);
private:
void closeImmediately() {
socket_.close();
socket_ = null;
}
T yieldOnBlock(T)(lazy T call, TAsyncEventType eventType) {
while (true) {
auto result = call();
if (result != Socket.ERROR || getSocketErrno() != WOULD_BLOCK_ERRNO) return result;
// We got an EAGAIN result, register a callback to return here once some
// event happens and yield.
Duration timeout = void;
final switch (eventType) {
case TAsyncEventType.READ:
timeout = recvTimeout_;
break;
case TAsyncEventType.WRITE:
timeout = sendTimeout_;
break;
}
auto fiber = Fiber.getThis();
assert(fiber, "Current fiber null not running in TAsyncManager?");
TAsyncEventReason eventReason = void;
asyncManager_.addOneshotListener(socket_, eventType, timeout,
scopedDelegate((TAsyncEventReason reason) {
eventReason = reason;
fiber.call();
})
);
// Yields execution back to the async manager, will return back here once
// the above listener is called.
Fiber.yield();
if (eventReason == TAsyncEventReason.TIMED_OUT) {
// If we are cancelling the request due to a timed out operation, the
// connection is in an undefined state, because the server could decide
// to send the requested data later, or we could have already been half-
// way into writing a request. Thus, we close the connection to make any
// possibly queued up work items fail immediately. Besides, the server
// is not very likely to immediately recover after a socket-level
// timeout has expired anyway.
closeImmediately();
throw new TTransportException("Timed out while waiting for socket " ~
"to get ready to " ~ to!string(eventType) ~ ".",
TTransportException.Type.TIMED_OUT);
}
}
}
/// The TAsyncSocketManager to use for non-blocking I/O.
TAsyncSocketManager asyncManager_;
}
private {
// std.socket doesn't include SO_ERROR for reasons unknown.
version (linux) {
enum SO_ERROR = 4;
} else version (OSX) {
enum SO_ERROR = 0x1007;
} else version (FreeBSD) {
enum SO_ERROR = 0x1007;
} else version (Win32) {
import std.c.windows.winsock : SO_ERROR;
} else static assert(false, "Don't know SO_ERROR on this platform.");
// This hack forces a delegate literal to be scoped, even if it is passed to
// a function accepting normal delegates as well. DMD likes to allocate the
// context on the heap anyway, but it seems to work for LDC.
import std.traits : isDelegate;
auto scopedDelegate(D)(scope D d) if (isDelegate!D) {
return d;
}
}

View file

@ -0,0 +1,292 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.async.ssl;
import core.thread : Fiber;
import core.time : Duration;
import std.array : empty;
import std.conv : to;
import std.exception : enforce;
import std.socket;
import deimos.openssl.err;
import deimos.openssl.ssl;
import thrift.base;
import thrift.async.base;
import thrift.async.socket;
import thrift.internal.ssl;
import thrift.internal.ssl_bio;
import thrift.transport.base;
import thrift.transport.ssl;
/**
* Provides SSL/TLS encryption for async sockets.
*
* This implementation should be considered experimental, as it context-switches
* between fibers from within OpenSSL calls, and the safety of this has not yet
* been verified.
*
* For obvious reasons (the SSL connection is stateful), more than one instance
* should never be used on a given socket at the same time.
*/
// Note: This could easily be extended to other transports in the future as well.
// There are only two parts of the implementation which don't work with a generic
// TTransport: 1) the certificate verification, for which peer name/address are
// needed from the socket, and 2) the connection shutdown, where the associated
// async manager is needed because close() is not usually called from within a
// work item.
final class TAsyncSSLSocket : TBaseTransport {
/**
* Constructor.
*
* Params:
* context = The SSL socket context to use. A reference to it is stored so
* that it does not get cleaned up while the socket is used.
* transport = The underlying async network transport to use for
* communication.
*/
this(TAsyncSocket underlyingSocket, TSSLContext context) {
socket_ = underlyingSocket;
context_ = context;
serverSide_ = context.serverSide;
accessManager_ = context.accessManager;
}
override bool isOpen() @property {
if (ssl_ is null || !socket_.isOpen) return false;
auto shutdown = SSL_get_shutdown(ssl_);
bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN) != 0;
bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN) != 0;
return !(shutdownReceived && shutdownSent);
}
override bool peek() {
if (!isOpen) return false;
checkHandshake();
byte bt = void;
auto rc = SSL_peek(ssl_, &bt, bt.sizeof);
sslEnforce(rc >= 0, "SSL_peek");
if (rc == 0) {
ERR_clear_error();
}
return (rc > 0);
}
override void open() {
enforce(!serverSide_, "Cannot open a server-side SSL socket.");
if (isOpen) return;
if (ssl_) {
// If the underlying socket was automatically closed because of an error
// (i.e. close() was called from inside a socket method), we can land
// here with the SSL object still allocated; delete it here.
cleanupSSL();
}
socket_.open();
}
override void close() {
if (!isOpen) return;
if (ssl_ !is null) {
// SSL needs to send/receive data over the socket as part of the shutdown
// protocol, so we must execute the calls in the context of the associated
// async manager. On the other hand, TTransport clients expect the socket
// to be closed when close() returns, so we have to block until the
// shutdown work item has been executed.
import core.sync.condition;
import core.sync.mutex;
int rc = void;
auto doneMutex = new Mutex;
auto doneCond = new Condition(doneMutex);
synchronized (doneMutex) {
socket_.asyncManager.execute(socket_, {
rc = SSL_shutdown(ssl_);
if (rc == 0) {
rc = SSL_shutdown(ssl_);
}
synchronized (doneMutex) doneCond.notifyAll();
});
doneCond.wait();
}
if (rc < 0) {
// Do not throw an exception here as leaving the transport "open" will
// probably produce only more errors, and the chance we can do
// something about the error e.g. by retrying is very low.
logError("Error while shutting down SSL: %s", getSSLException());
}
cleanupSSL();
}
socket_.close();
}
override size_t read(ubyte[] buf) {
checkHandshake();
auto rc = SSL_read(ssl_, buf.ptr, cast(int)buf.length);
sslEnforce(rc >= 0, "SSL_read");
return rc;
}
override void write(in ubyte[] buf) {
checkHandshake();
// Loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
size_t written = 0;
while (written < buf.length) {
auto bytes = SSL_write(ssl_, buf.ptr + written,
cast(int)(buf.length - written));
sslEnforce(bytes > 0, "SSL_write");
written += bytes;
}
}
override void flush() {
checkHandshake();
auto bio = SSL_get_wbio(ssl_);
enforce(bio !is null, new TSSLException("SSL_get_wbio returned null"));
auto rc = BIO_flush(bio);
sslEnforce(rc == 1, "BIO_flush");
}
/**
* Whether to use client or server side SSL handshake protocol.
*/
bool serverSide() @property const {
return serverSide_;
}
/// Ditto
void serverSide(bool value) @property {
serverSide_ = value;
}
/**
* The access manager to use.
*/
void accessManager(TAccessManager value) @property {
accessManager_ = value;
}
private:
/**
* If the condition is false, cleans up the SSL connection and throws the
* exception for the last SSL error.
*/
void sslEnforce(bool condition, string location) {
if (!condition) {
// We need to fetch the error first, as the error stack will be cleaned
// when shutting down SSL.
auto e = getSSLException(location);
cleanupSSL();
throw e;
}
}
/**
* Frees the SSL connection object and clears the SSL error state.
*/
void cleanupSSL() {
SSL_free(ssl_);
ssl_ = null;
ERR_remove_state(0);
}
/**
* Makes sure the SSL connection is up and running, and initializes it if not.
*/
void checkHandshake() {
enforce(socket_.isOpen, new TTransportException(
TTransportException.Type.NOT_OPEN));
if (ssl_ !is null) return;
ssl_ = context_.createSSL();
auto bio = createTTransportBIO(socket_, false);
SSL_set_bio(ssl_, bio, bio);
int rc = void;
if (serverSide_) {
rc = SSL_accept(ssl_);
} else {
rc = SSL_connect(ssl_);
}
enforce(rc > 0, getSSLException());
auto addr = socket_.getPeerAddress();
authorize(ssl_, accessManager_, addr,
(serverSide_ ? addr.toHostNameString() : socket_.host));
}
TAsyncSocket socket_;
bool serverSide_;
SSL* ssl_;
TSSLContext context_;
TAccessManager accessManager_;
}
/**
* Wraps passed TAsyncSocket instances into TAsyncSSLSockets.
*
* Typically used with TAsyncClient. As an unfortunate consequence of the
* async client design, the passed transports cannot be statically verified to
* be of type TAsyncSocket. Instead, the type is verified at runtime if a
* transport of an unexpected type is passed to getTransport(), it fails,
* throwing a TTransportException.
*
* Example:
* ---
* auto context = nwe TSSLContext();
* ... // Configure SSL context.
* auto factory = new TAsyncSSLSocketFactory(context);
*
* auto socket = new TAsyncSocket(someAsyncManager, host, port);
* socket.open();
*
* auto client = new TAsyncClient!Service(transport, factory,
* new TBinaryProtocolFactory!());
* ---
*/
class TAsyncSSLSocketFactory : TTransportFactory {
///
this(TSSLContext context) {
context_ = context;
}
override TAsyncSSLSocket getTransport(TTransport transport) {
auto socket = cast(TAsyncSocket)transport;
enforce(socket, new TTransportException(
"TAsyncSSLSocketFactory requires a TAsyncSocket to work on, not a " ~
to!string(typeid(transport)) ~ ".",
TTransportException.Type.INTERNAL_ERROR
));
return new TAsyncSSLSocket(socket, context_);
}
private:
TSSLContext context_;
}

View file

@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.base;
/**
* Common base class for all Thrift exceptions.
*/
class TException : Exception {
///
this(string msg = "", string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
}
}
/**
* An operation failed because one or more sub-tasks failed.
*/
class TCompoundOperationException : TException {
///
this(string msg, Exception[] exceptions, string file = __FILE__,
size_t line = __LINE__, Throwable next = null)
{
super(msg, file, line, next);
this.exceptions = exceptions;
}
/// The exceptions thrown by the children of the operation. If applicable,
/// the list is ordered in the same way the exceptions occurred.
Exception[] exceptions;
}
/// The Thrift version string, used for informative purposes.
// Note: This is currently hardcoded, but will likely be filled in by the build
// system in future versions.
enum VERSION = "0.10.0";
/**
* Functions used for logging inside Thrift.
*
* By default, the formatted messages are written to stdout/stderr, but this
* behavior can be overwritten by providing custom g_{Info, Error}LogSink
* handlers.
*
* Examples:
* ---
* logInfo("An informative message.");
* logError("Some error occurred: %s", e);
* ---
*/
alias logFormatted!g_infoLogSink logInfo;
alias logFormatted!g_errorLogSink logError; /// Ditto
/**
* Error and info log message sinks.
*
* These delegates are called with the log message passed as const(char)[]
* argument, and can be overwritten to hook the Thrift libraries up with a
* custom logging system. By default, they forward all output to stdout/stderr.
*/
__gshared void delegate(const(char)[]) g_infoLogSink;
__gshared void delegate(const(char)[]) g_errorLogSink; /// Ditto
shared static this() {
import std.stdio;
g_infoLogSink = (const(char)[] text) {
stdout.writeln(text);
};
g_errorLogSink = (const(char)[] text) {
stderr.writeln(text);
};
}
// This should be private, if it could still be used through the aliases then.
template logFormatted(alias target) {
void logFormatted(string file = __FILE__, int line = __LINE__,
T...)(string fmt, T args) if (
__traits(compiles, { target(""); })
) {
import std.format, std.stdio;
if (target !is null) {
scope(exit) g_formatBuffer.clear();
// Phobos @@BUG@@: If the empty string put() is removed, Appender.data
// stays empty.
g_formatBuffer.put("");
formattedWrite(g_formatBuffer, "%s:%s: ", file, line);
static if (T.length == 0) {
g_formatBuffer.put(fmt);
} else {
formattedWrite(g_formatBuffer, fmt, args);
}
target(g_formatBuffer.data);
}
}
}
private {
// Use a global, but thread-local buffer for constructing log messages.
import std.array : Appender;
Appender!(char[]) g_formatBuffer;
}

View file

@ -0,0 +1,255 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.async_client;
import std.conv : text, to;
import std.traits : ParameterStorageClass, ParameterStorageClassTuple,
ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.async.base;
import thrift.codegen.base;
import thrift.codegen.client;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.util.cancellation;
import thrift.util.future;
/**
* Asynchronous Thrift service client which returns the results as TFutures an
* uses a TAsyncManager to perform the actual work.
*
* TAsyncClientBase serves as a supertype for all TAsyncClients for the same
* service, which might be instantiated with different concrete protocol types
* (there is no covariance for template type parameters), and extends
* TFutureInterface!Interface. If Interface is derived from another service
* BaseInterface, it also extends TAsyncClientBase!BaseInterface.
*
* TAsyncClient implements TAsyncClientBase and offers two constructors with
* the following signatures:
* ---
* this(TAsyncTransport trans, TTransportFactory tf, TProtocolFactory pf);
* this(TAsyncTransport trans, TTransportFactory itf, TTransportFactory otf,
* TProtocolFactory ipf, TProtocolFactory opf);
* ---
*
* Again, if Interface represents a derived Thrift service,
* TAsyncClient!Interface is also derived from TAsyncClient!BaseInterface.
*
* TAsyncClient can exclusively be used with TAsyncTransports, as it needs to
* access the associated TAsyncManager. To set up any wrapper transports
* (e.g. buffered, framed) on top of it and to instanciate the protocols to use,
* TTransportFactory and TProtocolFactory instances are passed to the
* constructors the three argument constructor is a shortcut if the same
* transport and protocol are to be used for both input and output, which is
* the most common case.
*
* If the same transport factory is passed for both input and output transports,
* only a single wrapper transport will be created and used for both directions.
* This allows easy implementation of protocols like SSL.
*
* Just as TClient does, TAsyncClient also takes two optional template
* arguments which can be used for specifying the actual TProtocol
* implementation used for optimization purposes, as virtual calls can
* completely be eliminated then. If the actual types of the protocols
* instantiated by the factories used does not match the ones statically
* specified in the template parameters, a TException is thrown during
* construction.
*
* Example:
* ---
* // A simple Thrift service.
* interface Foo { int foo(); }
*
* // Create a TAsyncSocketManager thrift.async.libevent is used for this
* // example.
* auto manager = new TLibeventAsyncManager;
*
* // Set up an async transport to use.
* auto socket = new TAsyncSocket(manager, host, port);
*
* // Create a client instance.
* auto client = new TAsyncClient!Foo(
* socket,
* new TBufferedTransportFactory, // Wrap the socket in a TBufferedTransport.
* new TBinaryProtocolFactory!() // Use the Binary protocol.
* );
*
* // Call foo and use the returned future.
* auto result = client.foo();
* pragma(msg, typeof(result)); // TFuture!int
* int resultValue = result.waitGet(); // Waits until the result is available.
* ---
*/
interface TAsyncClientBase(Interface) if (isBaseService!Interface) :
TFutureInterface!Interface
{
/**
* The underlying TAsyncTransport used by this client instance.
*/
TAsyncTransport transport() @property;
}
/// Ditto
interface TAsyncClientBase(Interface) if (isDerivedService!Interface) :
TAsyncClientBase!(BaseService!Interface), TFutureInterface!Interface
{}
/// Ditto
template TAsyncClient(Interface, InputProtocol = TProtocol, OutputProtocol = void) if (
isService!Interface && isTProtocol!InputProtocol &&
(isTProtocol!OutputProtocol || is(OutputProtocol == void))
) {
mixin({
static if (isDerivedService!Interface) {
string code = "class TAsyncClient : " ~
"TAsyncClient!(BaseService!Interface, InputProtocol, OutputProtocol), " ~
"TAsyncClientBase!Interface {\n";
code ~= q{
this(TAsyncTransport trans, TTransportFactory tf, TProtocolFactory pf) {
this(trans, tf, tf, pf, pf);
}
this(TAsyncTransport trans, TTransportFactory itf,
TTransportFactory otf, TProtocolFactory ipf, TProtocolFactory opf
) {
super(trans, itf, otf, ipf, opf);
client_ = new typeof(client_)(iprot_, oprot_);
}
private TClient!(Interface, IProt, OProt) client_;
};
} else {
string code = "class TAsyncClient : TAsyncClientBase!Interface {";
code ~= q{
alias InputProtocol IProt;
static if (isTProtocol!OutputProtocol) {
alias OutputProtocol OProt;
} else {
static assert(is(OutputProtocol == void));
alias InputProtocol OProt;
}
this(TAsyncTransport trans, TTransportFactory tf, TProtocolFactory pf) {
this(trans, tf, tf, pf, pf);
}
this(TAsyncTransport trans, TTransportFactory itf,
TTransportFactory otf, TProtocolFactory ipf, TProtocolFactory opf
) {
import std.exception;
transport_ = trans;
auto ip = itf.getTransport(trans);
TTransport op = void;
if (itf == otf) {
op = ip;
} else {
op = otf.getTransport(trans);
}
auto iprot = ipf.getProtocol(ip);
iprot_ = cast(IProt)iprot;
enforce(iprot_, new TException(text("Input protocol not of the " ~
"specified concrete type (", IProt.stringof, ").")));
auto oprot = opf.getProtocol(op);
oprot_ = cast(OProt)oprot;
enforce(oprot_, new TException(text("Output protocol not of the " ~
"specified concrete type (", OProt.stringof, ").")));
client_ = new typeof(client_)(iprot_, oprot_);
}
override TAsyncTransport transport() @property {
return transport_;
}
protected TAsyncTransport transport_;
protected IProt iprot_;
protected OProt oprot_;
private TClient!(Interface, IProt, OProt) client_;
};
}
foreach (methodName;
FilterMethodNames!(Interface, __traits(derivedMembers, Interface))
) {
string[] paramList;
string[] paramNames;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
immutable paramName = "param" ~ to!string(i + 1);
immutable storage = ParameterStorageClassTuple!(
mixin("Interface." ~ methodName))[i];
paramList ~= ((storage & ParameterStorageClass.ref_) ? "ref " : "") ~
"ParameterTypeTuple!(Interface." ~ methodName ~ ")[" ~
to!string(i) ~ "] " ~ paramName;
paramNames ~= paramName;
}
paramList ~= "TCancellation cancellation = null";
immutable returnTypeCode = "ReturnType!(Interface." ~ methodName ~ ")";
code ~= "TFuture!(" ~ returnTypeCode ~ ") " ~ methodName ~ "(" ~
ctfeJoin(paramList) ~ ") {\n";
// Create the future instance that will repesent the result.
code ~= "auto promise = new TPromise!(" ~ returnTypeCode ~ ");\n";
// Prepare delegate which executes the TClient method call.
code ~= "auto work = {\n";
code ~= "try {\n";
code ~= "static if (is(ReturnType!(Interface." ~ methodName ~
") == void)) {\n";
code ~= "client_." ~ methodName ~ "(" ~ ctfeJoin(paramNames) ~ ");\n";
code ~= "promise.succeed();\n";
code ~= "} else {\n";
code ~= "auto result = client_." ~ methodName ~ "(" ~
ctfeJoin(paramNames) ~ ");\n";
code ~= "promise.succeed(result);\n";
code ~= "}\n";
code ~= "} catch (Exception e) {\n";
code ~= "promise.fail(e);\n";
code ~= "}\n";
code ~= "};\n";
// If the request is cancelled, set the result promise to cancelled
// as well. This could be moved into an additional TAsyncWorkItem
// delegate parameter.
code ~= q{
if (cancellation) {
cancellation.triggering.addCallback({
promise.cancel();
});
}
};
// Enqueue the work item and immediately return the promise (resp. its
// future interface).
code ~= "transport_.asyncManager.execute(transport_, work, cancellation);\n";
code ~= "return promise;\n";
code ~= "}\n";
}
code ~= "}\n";
return code;
}());
}

View file

@ -0,0 +1,906 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Utilities for asynchronously querying multiple servers, building on
* TAsyncClient.
*
* Terminology note: The names of the artifacts defined in this module are
* derived from »client pool«, because they operate on a pool of
* TAsyncClients. However, from a architectural point of view, they often
* represent a pool of hosts a Thrift client application communicates with
* using RPC calls.
*/
module thrift.codegen.async_client_pool;
import core.sync.mutex;
import core.time : Duration, dur;
import std.algorithm : map;
import std.array : array, empty;
import std.exception : enforce;
import std.traits : ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.codegen.base;
import thrift.codegen.async_client;
import thrift.internal.algorithm;
import thrift.internal.codegen;
import thrift.util.awaitable;
import thrift.util.cancellation;
import thrift.util.future;
import thrift.internal.resource_pool;
/**
* Represents a generic client pool which implements TFutureInterface!Interface
* using multiple TAsyncClients.
*/
interface TAsyncClientPoolBase(Interface) if (isService!Interface) :
TFutureInterface!Interface
{
/// Shorthand for the client type this pool operates on.
alias TAsyncClientBase!Interface Client;
/**
* Adds a client to the pool.
*/
void addClient(Client client);
/**
* Removes a client from the pool.
*
* Returns: Whether the client was found in the pool.
*/
bool removeClient(Client client);
/**
* Called to determine whether an exception comes from a client from the
* pool not working properly, or if it an exception thrown at the
* application level.
*
* If the delegate returns true, the server/connection is considered to be
* at fault, if it returns false, the exception is just passed on to the
* caller.
*
* By default, returns true for instances of TTransportException and
* TApplicationException, false otherwise.
*/
bool delegate(Exception) rpcFaultFilter() const @property;
void rpcFaultFilter(bool delegate(Exception)) @property; /// Ditto
/**
* Whether to open the underlying transports of a client before trying to
* execute a method if they are not open. This is usually desirable
* because it allows e.g. to automatically reconnect to a remote server
* if the network connection is dropped.
*
* Defaults to true.
*/
bool reopenTransports() const @property;
void reopenTransports(bool) @property; /// Ditto
}
immutable bool delegate(Exception) defaultRpcFaultFilter;
static this() {
defaultRpcFaultFilter = (Exception e) {
import thrift.protocol.base;
import thrift.transport.base;
return (
(cast(TTransportException)e !is null) ||
(cast(TApplicationException)e !is null)
);
};
}
/**
* A TAsyncClientPoolBase implementation which queries multiple servers in a
* row until a request succeeds, the result of which is then returned.
*
* The definition of »success« can be customized using the rpcFaultFilter()
* delegate property. If it is non-null and calling it for an exception set by
* a failed method invocation returns true, the error is considered to be
* caused by the RPC layer rather than the application layer, and the next
* server in the pool is tried. If there are no more clients to try, the
* operation is marked as failed with a TCompoundOperationException.
*
* If a TAsyncClient in the pool fails with an RPC exception for a number of
* consecutive tries, it is temporarily disabled (not tried any longer) for
* a certain duration. Both the limit and the timeout can be configured. If all
* clients fail (and keepTrying is false), the operation fails with a
* TCompoundOperationException which contains the collected RPC exceptions.
*/
final class TAsyncClientPool(Interface) if (isService!Interface) :
TAsyncClientPoolBase!Interface
{
///
this(Client[] clients) {
pool_ = new TResourcePool!Client(clients);
rpcFaultFilter_ = defaultRpcFaultFilter;
reopenTransports_ = true;
}
/+override+/ void addClient(Client client) {
pool_.add(client);
}
/+override+/ bool removeClient(Client client) {
return pool_.remove(client);
}
/**
* Whether to keep trying to find a working client if all have failed in a
* row.
*
* Defaults to false.
*/
bool keepTrying() const @property {
return pool_.cycle;
}
/// Ditto
void keepTrying(bool value) @property {
pool_.cycle = value;
}
/**
* Whether to use a random permutation of the client pool on every call to
* execute(). This can be used e.g. as a simple form of load balancing.
*
* Defaults to true.
*/
bool permuteClients() const @property {
return pool_.permute;
}
/// Ditto
void permuteClients(bool value) @property {
pool_.permute = value;
}
/**
* The number of consecutive faults after which a client is disabled until
* faultDisableDuration has passed. 0 to never disable clients.
*
* Defaults to 0.
*/
ushort faultDisableCount() const @property {
return pool_.faultDisableCount;
}
/// Ditto
void faultDisableCount(ushort value) @property {
pool_.faultDisableCount = value;
}
/**
* The duration for which a client is no longer considered after it has
* failed too often.
*
* Defaults to one second.
*/
Duration faultDisableDuration() const @property {
return pool_.faultDisableDuration;
}
/// Ditto
void faultDisableDuration(Duration value) @property {
pool_.faultDisableDuration = value;
}
/+override+/ bool delegate(Exception) rpcFaultFilter() const @property {
return rpcFaultFilter_;
}
/+override+/ void rpcFaultFilter(bool delegate(Exception) value) @property {
rpcFaultFilter_ = value;
}
/+override+/ bool reopenTransports() const @property {
return reopenTransports_;
}
/+override+/ void reopenTransports(bool value) @property {
reopenTransports_ = value;
}
mixin(fallbackPoolForwardCode!Interface());
protected:
// The actual worker implementation to which RPC method calls are forwarded.
auto executeOnPool(string method, Args...)(Args args,
TCancellation cancellation
) {
auto clients = pool_[];
if (clients.empty) {
throw new TException("No clients available to try.");
}
auto promise = new TPromise!(ReturnType!(MemberType!(Interface, method)));
Exception[] rpcExceptions;
void tryNext() {
while (clients.empty) {
Client next;
Duration waitTime;
if (clients.willBecomeNonempty(next, waitTime)) {
if (waitTime > dur!"hnsecs"(0)) {
if (waitTime < dur!"usecs"(10)) {
import core.thread;
Thread.sleep(waitTime);
} else {
next.transport.asyncManager.delay(waitTime, { tryNext(); });
return;
}
}
} else {
promise.fail(new TCompoundOperationException("All clients failed.",
rpcExceptions));
return;
}
}
auto client = clients.front;
clients.popFront;
if (reopenTransports) {
if (!client.transport.isOpen) {
try {
client.transport.open();
} catch (Exception e) {
pool_.recordFault(client);
tryNext();
return;
}
}
}
auto future = mixin("client." ~ method)(args, cancellation);
future.completion.addCallback({
if (future.status == TFutureStatus.CANCELLED) {
promise.cancel();
return;
}
auto e = future.getException();
if (e) {
if (rpcFaultFilter_ && rpcFaultFilter_(e)) {
pool_.recordFault(client);
rpcExceptions ~= e;
tryNext();
return;
}
}
pool_.recordSuccess(client);
promise.complete(future);
});
}
tryNext();
return promise;
}
private:
TResourcePool!Client pool_;
bool delegate(Exception) rpcFaultFilter_;
bool reopenTransports_;
}
/**
* TAsyncClientPool construction helper to avoid having to explicitly
* specify the interface type, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TAsyncClientPool!Interface tAsyncClientPool(Interface)(
TAsyncClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}
private {
// Cannot use an anonymous delegate literal for this because they aren't
// allowed in class scope.
string fallbackPoolForwardCode(Interface)() {
string code = "";
foreach (methodName; AllMemberMethodNames!Interface) {
enum qn = "Interface." ~ methodName;
code ~= "TFuture!(ReturnType!(" ~ qn ~ ")) " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ ") args, TCancellation cancellation = null) {\n";
code ~= "return executeOnPool!(`" ~ methodName ~ "`)(args, cancellation);\n";
code ~= "}\n";
}
return code;
}
}
/**
* A TAsyncClientPoolBase implementation which queries multiple servers at
* the same time and returns the first success response.
*
* The definition of »success« can be customized using the rpcFaultFilter()
* delegate property. If it is non-null and calling it for an exception set by
* a failed method invocation returns true, the error is considered to be
* caused by the RPC layer rather than the application layer, and the next
* server in the pool is tried. If all clients fail, the operation is marked
* as failed with a TCompoundOperationException.
*/
final class TAsyncFastestClientPool(Interface) if (isService!Interface) :
TAsyncClientPoolBase!Interface
{
///
this(Client[] clients) {
clients_ = clients;
rpcFaultFilter_ = defaultRpcFaultFilter;
reopenTransports_ = true;
}
/+override+/ void addClient(Client client) {
clients_ ~= client;
}
/+override+/ bool removeClient(Client client) {
auto oldLength = clients_.length;
clients_ = removeEqual(clients_, client);
return clients_.length < oldLength;
}
/+override+/ bool delegate(Exception) rpcFaultFilter() const @property {
return rpcFaultFilter_;
}
/+override+/ void rpcFaultFilter(bool delegate(Exception) value) @property {
rpcFaultFilter_ = value;
}
/+override+/bool reopenTransports() const @property {
return reopenTransports_;
}
/+override+/ void reopenTransports(bool value) @property {
reopenTransports_ = value;
}
mixin(fastestPoolForwardCode!Interface());
private:
Client[] clients_;
bool delegate(Exception) rpcFaultFilter_;
bool reopenTransports_;
}
/**
* TAsyncFastestClientPool construction helper to avoid having to explicitly
* specify the interface type, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TAsyncFastestClientPool!Interface tAsyncFastestClientPool(Interface)(
TAsyncClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}
private {
// Cannot use an anonymous delegate literal for this because they aren't
// allowed in class scope.
string fastestPoolForwardCode(Interface)() {
string code = "";
foreach (methodName; AllMemberMethodNames!Interface) {
enum qn = "Interface." ~ methodName;
code ~= "TFuture!(ReturnType!(" ~ qn ~ ")) " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ ") args, " ~
"TCancellation cancellation = null) {\n";
code ~= "enum methodName = `" ~ methodName ~ "`;\n";
code ~= q{
alias ReturnType!(MemberType!(Interface, methodName)) ResultType;
auto childCancellation = new TCancellationOrigin;
TFuture!ResultType[] futures;
futures.reserve(clients_.length);
foreach (c; clients_) {
if (reopenTransports) {
if (!c.transport.isOpen) {
try {
c.transport.open();
} catch (Exception e) {
continue;
}
}
}
futures ~= mixin("c." ~ methodName)(args, childCancellation);
}
return new FastestPoolJob!(ResultType)(
futures, rpcFaultFilter, cancellation, childCancellation);
};
code ~= "}\n";
}
return code;
}
final class FastestPoolJob(Result) : TFuture!Result {
this(TFuture!Result[] poolFutures, bool delegate(Exception) rpcFaultFilter,
TCancellation cancellation, TCancellationOrigin childCancellation
) {
resultPromise_ = new TPromise!Result;
poolFutures_ = poolFutures;
rpcFaultFilter_ = rpcFaultFilter;
childCancellation_ = childCancellation;
foreach (future; poolFutures) {
future.completion.addCallback({
auto f = future;
return { completionCallback(f); };
}());
if (future.status != TFutureStatus.RUNNING) {
// If the current future is already completed, we are done, don't
// bother adding callbacks for the others (they would just return
// immediately after acquiring the lock).
return;
}
}
if (cancellation) {
cancellation.triggering.addCallback({
resultPromise_.cancel();
childCancellation.trigger();
});
}
}
TFutureStatus status() const @property {
return resultPromise_.status;
}
TAwaitable completion() @property {
return resultPromise_.completion;
}
Result get() {
return resultPromise_.get();
}
Exception getException() {
return resultPromise_.getException();
}
private:
void completionCallback(TFuture!Result future) {
synchronized {
if (future.status == TFutureStatus.CANCELLED) {
assert(resultPromise_.status != TFutureStatus.RUNNING);
return;
}
if (resultPromise_.status != TFutureStatus.RUNNING) {
// The operation has already been completed. This can happen if
// another client completed first, but this callback was already
// waiting for the lock when it called cancel().
return;
}
if (future.status == TFutureStatus.FAILED) {
auto e = future.getException();
if (rpcFaultFilter_ && rpcFaultFilter_(e)) {
rpcExceptions_ ~= e;
if (rpcExceptions_.length == poolFutures_.length) {
resultPromise_.fail(new TCompoundOperationException(
"All child operations failed, unable to retrieve a result.",
rpcExceptions_
));
}
return;
}
}
// Store the result to the target promise.
resultPromise_.complete(future);
// Cancel the other futures, we would just discard their results.
// Note: We do this after we have stored the results to our promise,
// see the assert at the top of the function.
childCancellation_.trigger();
}
}
TPromise!Result resultPromise_;
TFuture!Result[] poolFutures_;
Exception[] rpcExceptions_;
bool delegate(Exception) rpcFaultFilter_;
TCancellationOrigin childCancellation_;
}
}
/**
* Allows easily aggregating results from a number of TAsyncClients.
*
* Contrary to TAsync{Fallback, Fastest}ClientPool, this class does not
* simply implement TFutureInterface!Interface. It manages a pool of clients,
* but allows the user to specify a custom accumulator function to use or to
* iterate over the results using a TFutureAggregatorRange.
*
* For each service method, TAsyncAggregator offers a method
* accepting the same arguments, and an optional TCancellation instance, just
* like with TFutureInterface. The return type, however, is a proxy object
* that offers the following methods:
* ---
* /++
* + Returns a thrift.util.future.TFutureAggregatorRange for the results of
* + the client pool method invocations.
* +
* + The [] (slicing) operator can also be used to obtain the range.
* +
* + Params:
* + timeout = A timeout to pass to the TFutureAggregatorRange constructor,
* + defaults to zero (no timeout).
* +/
* TFutureAggregatorRange!ReturnType range(Duration timeout = dur!"hnsecs"(0));
* auto opSlice() { return range(); } /// Ditto
*
* /++
* + Returns a future that gathers the results from the clients in the pool
* + and invokes a user-supplied accumulator function on them, returning its
* + return value to the client.
* +
* + In addition to the TFuture!AccumulatedType interface (where
* + AccumulatedType is the return type of the accumulator function), the
* + returned object also offers two additional methods, finish() and
* + finishGet(): By default, the accumulator functions is called after all
* + the results from the pool clients have become available. Calling finish()
* + causes the accumulator future to stop waiting for other results and
* + immediately invoking the accumulator function on the results currently
* + available. If all results are already available, finish() is a no-op.
* + finishGet() is a convenience shortcut for combining it with
* + a call to get() immediately afterwards, like waitGet() is for wait().
* +
* + The acc alias can point to any callable accepting either an array of
* + return values or an array of return values and an array of exceptions;
* + see isAccumulator!() for details. The default accumulator concatenates
* + return values that can be concatenated with each others (e.g. arrays),
* + and simply returns an array of values otherwise, failing with a
* + TCompoundOperationException no values were returned.
* +
* + The accumulator function is not executed in any of the async manager
* + worker threads associated with the async clients, but instead it is
* + invoked when the actual result is requested for the first time after the
* + operation has been completed. This also includes checking the status
* + of the operation once it is no longer running, since the accumulator
* + has to be run to determine whether the operation succeeded or failed.
* +/
* auto accumulate(alias acc = defaultAccumulator)() if (isAccumulator!acc);
* ---
*
* Example:
* ---
* // Some Thrift service.
* interface Foo {
* int foo(string name);
* byte[] bar();
* }
*
* // Create the aggregator pool client0, client1, client2 are some
* // TAsyncClient!Foo instances, but in theory could also be other
* // TFutureInterface!Foo implementations (e.g. some async client pool).
* auto pool = new TAsyncAggregator!Foo([client0, client1, client2]);
*
* foreach (val; pool.foo("baz").range(dur!"seconds"(1))) {
* // Process all the results that are available before a second has passed,
* // in the order they arrive.
* writeln(val);
* }
*
* auto sumRoots = pool.foo("baz").accumulate!((int[] vals, Exceptions[] exs){
* if (vals.empty) {
* throw new TCompoundOperationException("All clients failed", exs);
* }
*
* // Just to illustrate that the type of the values can change, convert the
* // numbers to double and sum up their roots.
* double result = 0;
* foreach (v; vals) result += sqrt(cast(double)v);
* return result;
* })();
*
* // Wait up to three seconds for the result, and then accumulate what has
* // arrived so far.
* sumRoots.completion.wait(dur!"seconds"(3));
* writeln(sumRoots.finishGet());
*
* // For scalars, the default accumulator returns an array of the values.
* pragma(msg, typeof(pool.foo("").accumulate().get()); // int[].
*
* // For lists, etc., it concatenates the results together.
* pragma(msg, typeof(pool.bar().accumulate().get())); // byte[].
* ---
*
* Note: For the accumulate!() interface, you might currently hit a »cannot use
* local '…' as parameter to non-global template accumulate«-error, see
* $(DMDBUG 5710, DMD issue 5710). If your accumulator function does not need
* to access the surrounding scope, you might want to use a function literal
* instead of a delegate to avoid the issue.
*/
class TAsyncAggregator(Interface) if (isBaseService!Interface) {
/// Shorthand for the client type this instance operates on.
alias TAsyncClientBase!Interface Client;
///
this(Client[] clients) {
clients_ = clients;
}
/// Whether to open the underlying transports of a client before trying to
/// execute a method if they are not open. This is usually desirable
/// because it allows e.g. to automatically reconnect to a remote server
/// if the network connection is dropped.
///
/// Defaults to true.
bool reopenTransports = true;
mixin AggregatorOpDispatch!();
private:
Client[] clients_;
}
/// Ditto
class TAsyncAggregator(Interface) if (isDerivedService!Interface) :
TAsyncAggregator!(BaseService!Interface)
{
/// Shorthand for the client type this instance operates on.
alias TAsyncClientBase!Interface Client;
///
this(Client[] clients) {
super(cast(TAsyncClientBase!(BaseService!Interface)[])clients);
}
mixin AggregatorOpDispatch!();
}
/**
* Whether fun is a valid accumulator function for values of type ValueType.
*
* For this to be true, fun must be a callable matching one of the following
* argument lists:
* ---
* fun(ValueType[] values);
* fun(ValueType[] values, Exception[] exceptions);
* ---
*
* The second version is passed the collected array exceptions from all the
* clients in the pool.
*
* The return value of the accumulator function is passed to the client (via
* the result future). If it throws an exception, the operation is marked as
* failed with the given exception instead.
*/
template isAccumulator(ValueType, alias fun) {
enum isAccumulator = is(typeof(fun(cast(ValueType[])[]))) ||
is(typeof(fun(cast(ValueType[])[], cast(Exception[])[])));
}
/**
* TAsyncAggregator construction helper to avoid having to explicitly
* specify the interface type, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TAsyncAggregator!Interface tAsyncAggregator(Interface)(
TAsyncClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}
private {
mixin template AggregatorOpDispatch() {
auto opDispatch(string name, Args...)(Args args) if (
is(typeof(mixin("Interface.init." ~ name)(args)))
) {
alias ReturnType!(MemberType!(Interface, name)) ResultType;
auto childCancellation = new TCancellationOrigin;
TFuture!ResultType[] futures;
futures.reserve(clients_.length);
foreach (c; cast(Client[])clients_) {
if (reopenTransports) {
if (!c.transport.isOpen) {
try {
c.transport.open();
} catch (Exception e) {
continue;
}
}
}
futures ~= mixin("c." ~ name)(args, childCancellation);
}
return AggregationResult!ResultType(futures, childCancellation);
}
}
struct AggregationResult(T) {
auto opSlice() {
return range();
}
auto range(Duration timeout = dur!"hnsecs"(0)) {
return tFutureAggregatorRange(futures_, childCancellation_, timeout);
}
auto accumulate(alias acc = defaultAccumulator)() if (isAccumulator!(T, acc)) {
return new AccumulatorJob!(T, acc)(futures_, childCancellation_);
}
private:
TFuture!T[] futures_;
TCancellationOrigin childCancellation_;
}
auto defaultAccumulator(T)(T[] values, Exception[] exceptions) {
if (values.empty) {
throw new TCompoundOperationException("All clients failed",
exceptions);
}
static if (is(typeof(T.init ~ T.init))) {
import std.algorithm;
return reduce!"a ~ b"(values);
} else {
return values;
}
}
final class AccumulatorJob(T, alias accumulator) if (
isAccumulator!(T, accumulator)
) : TFuture!(AccumulatorResult!(T, accumulator)) {
this(TFuture!T[] futures, TCancellationOrigin childCancellation) {
futures_ = futures;
childCancellation_ = childCancellation;
resultMutex_ = new Mutex;
completionEvent_ = new TOneshotEvent;
foreach (future; futures) {
future.completion.addCallback({
auto f = future;
return {
synchronized (resultMutex_) {
if (f.status == TFutureStatus.CANCELLED) {
if (!finished_) {
status_ = TFutureStatus.CANCELLED;
finished_ = true;
}
return;
}
if (f.status == TFutureStatus.FAILED) {
exceptions_ ~= f.getException();
} else {
results_ ~= f.get();
}
if (results_.length + exceptions_.length == futures_.length) {
finished_ = true;
completionEvent_.trigger();
}
}
};
}());
}
}
TFutureStatus status() @property {
synchronized (resultMutex_) {
if (!finished_) return TFutureStatus.RUNNING;
if (status_ != TFutureStatus.RUNNING) return status_;
try {
result_ = invokeAccumulator!accumulator(results_, exceptions_);
status_ = TFutureStatus.SUCCEEDED;
} catch (Exception e) {
exception_ = e;
status_ = TFutureStatus.FAILED;
}
return status_;
}
}
TAwaitable completion() @property {
return completionEvent_;
}
AccumulatorResult!(T, accumulator) get() {
auto s = status;
enforce(s != TFutureStatus.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == TFutureStatus.CANCELLED) throw new TCancelledException;
if (s == TFutureStatus.FAILED) throw exception_;
return result_;
}
Exception getException() {
auto s = status;
enforce(s != TFutureStatus.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == TFutureStatus.CANCELLED) throw new TCancelledException;
if (s == TFutureStatus.SUCCEEDED) {
return null;
}
return exception_;
}
void finish() {
synchronized (resultMutex_) {
if (!finished_) {
finished_ = true;
childCancellation_.trigger();
completionEvent_.trigger();
}
}
}
auto finishGet() {
finish();
return get();
}
private:
TFuture!T[] futures_;
TCancellationOrigin childCancellation_;
bool finished_;
T[] results_;
Exception[] exceptions_;
TFutureStatus status_;
Mutex resultMutex_;
union {
AccumulatorResult!(T, accumulator) result_;
Exception exception_;
}
TOneshotEvent completionEvent_;
}
auto invokeAccumulator(alias accumulator, T)(
T[] values, Exception[] exceptions
) if (
isAccumulator!(T, accumulator)
) {
static if (is(typeof(accumulator(values, exceptions)))) {
return accumulator(values, exceptions);
} else {
return accumulator(values);
}
}
template AccumulatorResult(T, alias acc) {
alias typeof(invokeAccumulator!acc(cast(T[])[], cast(Exception[])[]))
AccumulatorResult;
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,486 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.client;
import std.algorithm : find;
import std.array : empty, front;
import std.conv : to;
import std.traits : isSomeFunction, ParameterStorageClass,
ParameterStorageClassTuple, ParameterTypeTuple, ReturnType;
import thrift.codegen.base;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.protocol.base;
/**
* Thrift service client, which implements an interface by synchronously
* calling a server over a TProtocol.
*
* TClientBase simply extends Interface with generic input/output protocol
* properties to serve as a supertype for all TClients for the same service,
* which might be instantiated with different concrete protocol types (there
* is no covariance for template type parameters). If Interface is derived
* from another interface BaseInterface, it also extends
* TClientBase!BaseInterface.
*
* TClient is the class that actually implements TClientBase. Just as
* TClientBase, it is also derived from TClient!BaseInterface for inheriting
* services.
*
* TClient takes two optional template arguments which can be used for
* specifying the actual TProtocol implementation used for optimization
* purposes, as virtual calls can completely be eliminated then. If
* OutputProtocol is not specified, it is assumed to be the same as
* InputProtocol. The protocol properties defined by TClientBase are exposed
* with their concrete type (return type covariance).
*
* In addition to implementing TClientBase!Interface, TClient offers the
* following constructors:
* ---
* this(InputProtocol iprot, OutputProtocol oprot);
* // Only if is(InputProtocol == OutputProtocol), to use the same protocol
* // for both input and output:
* this(InputProtocol prot);
* ---
*
* The sequence id of the method calls starts at zero and is automatically
* incremented.
*/
interface TClientBase(Interface) if (isBaseService!Interface) : Interface {
/**
* The input protocol used by the client.
*/
TProtocol inputProtocol() @property;
/**
* The output protocol used by the client.
*/
TProtocol outputProtocol() @property;
}
/// Ditto
interface TClientBase(Interface) if (isDerivedService!Interface) :
TClientBase!(BaseService!Interface), Interface {}
/// Ditto
template TClient(Interface, InputProtocol = TProtocol, OutputProtocol = void) if (
isService!Interface && isTProtocol!InputProtocol &&
(isTProtocol!OutputProtocol || is(OutputProtocol == void))
) {
mixin({
static if (isDerivedService!Interface) {
string code = "class TClient : TClient!(BaseService!Interface, " ~
"InputProtocol, OutputProtocol), TClientBase!Interface {\n";
code ~= q{
this(IProt iprot, OProt oprot) {
super(iprot, oprot);
}
static if (is(IProt == OProt)) {
this(IProt prot) {
super(prot);
}
}
// DMD @@BUG@@: If these are not present in this class (would be)
// inherited anyway, »not implemented« errors are raised.
override IProt inputProtocol() @property {
return super.inputProtocol;
}
override OProt outputProtocol() @property {
return super.outputProtocol;
}
};
} else {
string code = "class TClient : TClientBase!Interface {";
code ~= q{
alias InputProtocol IProt;
static if (isTProtocol!OutputProtocol) {
alias OutputProtocol OProt;
} else {
static assert(is(OutputProtocol == void));
alias InputProtocol OProt;
}
this(IProt iprot, OProt oprot) {
iprot_ = iprot;
oprot_ = oprot;
}
static if (is(IProt == OProt)) {
this(IProt prot) {
this(prot, prot);
}
}
IProt inputProtocol() @property {
return iprot_;
}
OProt outputProtocol() @property {
return oprot_;
}
protected IProt iprot_;
protected OProt oprot_;
protected int seqid_;
};
}
foreach (methodName; __traits(derivedMembers, Interface)) {
static if (isSomeFunction!(mixin("Interface." ~ methodName))) {
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
enum meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
// Generate the code for sending.
string[] paramList;
string paramAssignCode;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
// Use the param name speficied in the meta information if any
// just cosmetics in this case.
string paramName;
if (methodMetaFound && i < methodMeta.params.length) {
paramName = methodMeta.params[i].name;
} else {
paramName = "param" ~ to!string(i + 1);
}
immutable storage = ParameterStorageClassTuple!(
mixin("Interface." ~ methodName))[i];
paramList ~= ((storage & ParameterStorageClass.ref_) ? "ref " : "") ~
"ParameterTypeTuple!(Interface." ~ methodName ~ ")[" ~
to!string(i) ~ "] " ~ paramName;
paramAssignCode ~= "args." ~ paramName ~ " = &" ~ paramName ~ ";\n";
}
code ~= "ReturnType!(Interface." ~ methodName ~ ") " ~ methodName ~
"(" ~ ctfeJoin(paramList) ~ ") {\n";
code ~= "immutable methodName = `" ~ methodName ~ "`;\n";
immutable paramStructType =
"TPargsStruct!(Interface, `" ~ methodName ~ "`)";
code ~= paramStructType ~ " args = " ~ paramStructType ~ "();\n";
code ~= paramAssignCode;
code ~= "oprot_.writeMessageBegin(TMessage(`" ~ methodName ~ "`, ";
code ~= ((methodMetaFound && methodMeta.type == TMethodType.ONEWAY)
? "TMessageType.ONEWAY" : "TMessageType.CALL");
code ~= ", ++seqid_));\n";
code ~= "args.write(oprot_);\n";
code ~= "oprot_.writeMessageEnd();\n";
code ~= "oprot_.transport.flush();\n";
// If this is not a oneway method, generate the receiving code.
if (!methodMetaFound || methodMeta.type != TMethodType.ONEWAY) {
code ~= "TPresultStruct!(Interface, `" ~ methodName ~ "`) result;\n";
if (!is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= "ReturnType!(Interface." ~ methodName ~ ") _return;\n";
code ~= "result.success = &_return;\n";
}
// TODO: The C++ implementation checks for matching method name here,
// should we do as well?
code ~= q{
auto msg = iprot_.readMessageBegin();
scope (exit) {
iprot_.readMessageEnd();
iprot_.transport.readEnd();
}
if (msg.type == TMessageType.EXCEPTION) {
auto x = new TApplicationException(null);
x.read(iprot_);
iprot_.transport.readEnd();
throw x;
}
if (msg.type != TMessageType.REPLY) {
skip(iprot_, TType.STRUCT);
iprot_.transport.readEnd();
}
if (msg.seqid != seqid_) {
throw new TApplicationException(
methodName ~ " failed: Out of sequence response.",
TApplicationException.Type.BAD_SEQUENCE_ID
);
}
result.read(iprot_);
};
if (methodMetaFound) {
foreach (e; methodMeta.exceptions) {
code ~= "if (result.isSet!`" ~ e.name ~ "`) throw result." ~
e.name ~ ";\n";
}
}
if (!is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= q{
if (result.isSet!`success`) return _return;
throw new TApplicationException(
methodName ~ " failed: Unknown result.",
TApplicationException.Type.MISSING_RESULT
);
};
}
}
code ~= "}\n";
}
}
code ~= "}\n";
return code;
}());
}
/**
* TClient construction helper to avoid having to explicitly specify
* the protocol types, i.e. to allow the constructor being called using IFTI
* (see $(DMDBUG 6082, D Bugzilla enhancement requet 6082)).
*/
TClient!(Interface, Prot) tClient(Interface, Prot)(Prot prot) if (
isService!Interface && isTProtocol!Prot
) {
return new TClient!(Interface, Prot)(prot);
}
/// Ditto
TClient!(Interface, IProt, Oprot) tClient(Interface, IProt, OProt)
(IProt iprot, OProt oprot) if (
isService!Interface && isTProtocol!IProt && isTProtocol!OProt
) {
return new TClient!(Interface, IProt, OProt)(iprot, oprot);
}
/**
* Represents the arguments of a Thrift method call, as pointers to the (const)
* parameter type to avoid copying.
*
* There should usually be no reason to use this struct directly without the
* help of TClient, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider this example:
* ---
* interface Foo {
* int bar(string a, bool b);
*
* enum methodMeta = [
* TMethodMeta("bar", [TParamMeta("a", 1), TParamMeta("b", 2)])
* ];
* }
*
* alias TPargsStruct!(Foo, "bar") FooBarPargs;
* ---
*
* The definition of FooBarPargs is equivalent to (ignoring the necessary
* metadata to assign the field IDs):
* ---
* struct FooBarPargs {
* const(string)* a;
* const(bool)* b;
*
* void write(Protocol)(Protocol proto) const if (isTProtocol!Protocol);
* }
* ---
*/
template TPargsStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
string memberCode;
string[] fieldMetaCodes;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
// If we have no meta information, just use param1, param2, etc. as
// field names, it shouldn't really matter anyway. 1-based »indexing«
// is used to match the common scheme in the Thrift world.
string memberId;
string memberName;
if (methodMetaFound && i < methodMeta.params.length) {
memberId = to!string(methodMeta.params[i].id);
memberName = methodMeta.params[i].name;
} else {
memberId = to!string(i + 1);
memberName = "param" ~ to!string(i + 1);
}
// Workaround for DMD @@BUG@@ 6056: make an intermediary alias for the
// parameter type, and declare the member using const(memberNameType)*.
memberCode ~= "alias ParameterTypeTuple!(Interface." ~ methodName ~
")[" ~ to!string(i) ~ "] " ~ memberName ~ "Type;\n";
memberCode ~= "const(" ~ memberName ~ "Type)* " ~ memberName ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ memberName ~ "`, " ~ memberId ~
", TReq.OPT_IN_REQ_OUT)";
}
string code = "struct TPargsStruct {\n";
code ~= memberCode;
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.base.TPargsStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
code ~= "void write(P)(P proto) const if (isTProtocol!P) {\n";
code ~= "writeStruct!(typeof(this), P, [" ~ ctfeJoin(fieldMetaCodes) ~
"], true)(this, proto);\n";
code ~= "}\n";
code ~= "}\n";
return code;
}());
}
/**
* Represents the result of a Thrift method call, using a pointer to the return
* value to avoid copying.
*
* There should usually be no reason to use this struct directly without the
* help of TClient, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider this example:
* ---
* interface Foo {
* int bar(string a);
*
* alias .FooException FooException;
*
* enum methodMeta = [
* TMethodMeta("bar",
* [TParamMeta("a", 1)],
* [TExceptionMeta("fooe", 1, "FooException")]
* )
* ];
* }
* alias TPresultStruct!(Foo, "bar") FooBarPresult;
* ---
*
* The definition of FooBarPresult is equivalent to (ignoring the necessary
* metadata to assign the field IDs):
* ---
* struct FooBarPresult {
* int* success;
* Foo.FooException fooe;
*
* struct IsSetFlags {
* bool success;
* }
* IsSetFlags isSetFlags;
*
* bool isSet(string fieldName)() const @property;
* void read(Protocol)(Protocol proto) if (isTProtocol!Protocol);
* }
* ---
*/
template TPresultStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
string code = "struct TPresultStruct {\n";
string[] fieldMetaCodes;
alias ReturnType!(mixin("Interface." ~ methodName)) ResultType;
static if (!is(ResultType == void)) {
code ~= q{
ReturnType!(mixin("Interface." ~ methodName))* success;
};
fieldMetaCodes ~= "TFieldMeta(`success`, 0, TReq.OPTIONAL)";
static if (!isNullable!ResultType) {
code ~= q{
struct IsSetFlags {
bool success;
}
IsSetFlags isSetFlags;
};
fieldMetaCodes ~= "TFieldMeta(`isSetFlags`, 0, TReq.IGNORE)";
}
}
bool methodMetaFound;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
foreach (e; meta.front.exceptions) {
code ~= "Interface." ~ e.type ~ " " ~ e.name ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ e.name ~ "`, " ~ to!string(e.id) ~
", TReq.OPTIONAL)";
}
methodMetaFound = true;
}
}
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.base.TPresultStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
code ~= q{
bool isSet(string fieldName)() const @property if (
is(MemberType!(typeof(this), fieldName))
) {
static if (fieldName == "success") {
static if (isNullable!(typeof(*success))) {
return *success !is null;
} else {
return isSetFlags.success;
}
} else {
// We are dealing with an exception member, which, being a nullable
// type (exceptions are always classes), has no isSet flag.
return __traits(getMember, this, fieldName) !is null;
}
}
};
code ~= "void read(P)(P proto) if (isTProtocol!P) {\n";
code ~= "readStruct!(typeof(this), P, [" ~ ctfeJoin(fieldMetaCodes) ~
"], true)(this, proto);\n";
code ~= "}\n";
code ~= "}\n";
return code;
}());
}

View file

@ -0,0 +1,262 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.client_pool;
import core.time : dur, Duration, TickDuration;
import std.traits : ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.codegen.base;
import thrift.codegen.client;
import thrift.internal.codegen;
import thrift.internal.resource_pool;
/**
* Manages a pool of TClients for the given interface, forwarding RPC calls to
* members of the pool.
*
* If a request fails, another client from the pool is tried, and optionally,
* a client is disabled for a configurable amount of time if it fails too
* often. If all clients fail (and keepTrying is false), a
* TCompoundOperationException is thrown, containing all the collected RPC
* exceptions.
*/
class TClientPool(Interface) if (isService!Interface) : Interface {
/// Shorthand for TClientBase!Interface, the client type this instance
/// operates on.
alias TClientBase!Interface Client;
/**
* Creates a new instance and adds the given clients to the pool.
*/
this(Client[] clients) {
pool_ = new TResourcePool!Client(clients);
rpcFaultFilter = (Exception e) {
import thrift.protocol.base;
import thrift.transport.base;
return (
(cast(TTransportException)e !is null) ||
(cast(TApplicationException)e !is null)
);
};
}
/**
* Executes an operation on the first currently active client.
*
* If the operation fails (throws an exception for which rpcFaultFilter is
* true), the failure is recorded and the next client in the pool is tried.
*
* Throws: Any non-rpc exception that occurs, a TCompoundOperationException
* if all clients failed with an rpc exception (if keepTrying is false).
*
* Example:
* ---
* interface Foo { string bar(); }
* auto poolClient = tClientPool([tClient!Foo(someProtocol)]);
* auto result = poolClient.execute((c){ return c.bar(); });
* ---
*/
ResultType execute(ResultType)(scope ResultType delegate(Client) work) {
return executeOnPool!Client(work);
}
/**
* Adds a client to the pool.
*/
void addClient(Client client) {
pool_.add(client);
}
/**
* Removes a client from the pool.
*
* Returns: Whether the client was found in the pool.
*/
bool removeClient(Client client) {
return pool_.remove(client);
}
mixin(poolForwardCode!Interface());
/// Whether to open the underlying transports of a client before trying to
/// execute a method if they are not open. This is usually desirable
/// because it allows e.g. to automatically reconnect to a remote server
/// if the network connection is dropped.
///
/// Defaults to true.
bool reopenTransports = true;
/// Called to determine whether an exception comes from a client from the
/// pool not working properly, or if it an exception thrown at the
/// application level.
///
/// If the delegate returns true, the server/connection is considered to be
/// at fault, if it returns false, the exception is just passed on to the
/// caller.
///
/// By default, returns true for instances of TTransportException and
/// TApplicationException, false otherwise.
bool delegate(Exception) rpcFaultFilter;
/**
* Whether to keep trying to find a working client if all have failed in a
* row.
*
* Defaults to false.
*/
bool keepTrying() const @property {
return pool_.cycle;
}
/// Ditto
void keepTrying(bool value) @property {
pool_.cycle = value;
}
/**
* Whether to use a random permutation of the client pool on every call to
* execute(). This can be used e.g. as a simple form of load balancing.
*
* Defaults to true.
*/
bool permuteClients() const @property {
return pool_.permute;
}
/// Ditto
void permuteClients(bool value) @property {
pool_.permute = value;
}
/**
* The number of consecutive faults after which a client is disabled until
* faultDisableDuration has passed. 0 to never disable clients.
*
* Defaults to 0.
*/
ushort faultDisableCount() @property {
return pool_.faultDisableCount;
}
/// Ditto
void faultDisableCount(ushort value) @property {
pool_.faultDisableCount = value;
}
/**
* The duration for which a client is no longer considered after it has
* failed too often.
*
* Defaults to one second.
*/
Duration faultDisableDuration() @property {
return pool_.faultDisableDuration;
}
/// Ditto
void faultDisableDuration(Duration value) @property {
pool_.faultDisableDuration = value;
}
protected:
ResultType executeOnPool(ResultType)(scope ResultType delegate(Client) work) {
auto clients = pool_[];
if (clients.empty) {
throw new TException("No clients available to try.");
}
while (true) {
Exception[] rpcExceptions;
while (!clients.empty) {
auto c = clients.front;
clients.popFront;
try {
scope (success) {
pool_.recordSuccess(c);
}
if (reopenTransports) {
c.inputProtocol.transport.open();
c.outputProtocol.transport.open();
}
return work(c);
} catch (Exception e) {
if (rpcFaultFilter && rpcFaultFilter(e)) {
pool_.recordFault(c);
rpcExceptions ~= e;
} else {
// We are dealing with a normal exception thrown by the
// server-side method, just pass it on. As far as we are
// concerned, the method call succeeded.
pool_.recordSuccess(c);
throw e;
}
}
}
// If we get here, no client succeeded during the current iteration.
Duration waitTime;
Client dummy;
if (clients.willBecomeNonempty(dummy, waitTime)) {
if (waitTime > dur!"hnsecs"(0)) {
import core.thread;
Thread.sleep(waitTime);
}
} else {
throw new TCompoundOperationException("All clients failed.",
rpcExceptions);
}
}
}
private:
TResourcePool!Client pool_;
}
private {
// Cannot use an anonymous delegate literal for this because they aren't
// allowed in class scope.
string poolForwardCode(Interface)() {
string code = "";
foreach (methodName; AllMemberMethodNames!Interface) {
enum qn = "Interface." ~ methodName;
code ~= "ReturnType!(" ~ qn ~ ") " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ ") args) {\n";
code ~= "return executeOnPool((Client c){ return c." ~
methodName ~ "(args); });\n";
code ~= "}\n";
}
return code;
}
}
/**
* TClientPool construction helper to avoid having to explicitly specify
* the interface type, i.e. to allow the constructor being called using IFTI
* (see $(DMDBUG 6082, D Bugzilla enhancement requet 6082)).
*/
TClientPool!Interface tClientPool(Interface)(
TClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}

View file

@ -0,0 +1,770 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Contains <b>experimental</b> functionality for generating Thrift IDL files
* (.thrift) from existing D data structures, i.e. the reverse of what the
* Thrift compiler does.
*/
module thrift.codegen.idlgen;
import std.algorithm : find;
import std.array : empty, front;
import std.conv : to;
import std.traits : EnumMembers, isSomeFunction, OriginalType,
ParameterTypeTuple, ReturnType;
import std.typetuple : allSatisfy, staticIndexOf, staticMap, NoDuplicates,
TypeTuple;
import thrift.base;
import thrift.codegen.base;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.util.hashset;
/**
* True if the passed type is a Thrift entity (struct, exception, enum,
* service).
*/
alias Any!(isStruct, isException, isEnum, isService) isThriftEntity;
/**
* Returns an IDL string describing the passed »root« entities and all types
* they depend on.
*/
template idlString(Roots...) if (allSatisfy!(isThriftEntity, Roots)) {
enum idlString = idlStringImpl!Roots.result;
}
private {
template idlStringImpl(Roots...) if (allSatisfy!(isThriftEntity, Roots)) {
alias ForAllWithList!(
ConfinedTuple!(StaticFilter!(isService, Roots)),
AddBaseServices
) Services;
alias TypeTuple!(
StaticFilter!(isEnum, Roots),
ForAllWithList!(
ConfinedTuple!(
StaticFilter!(Any!(isException, isStruct), Roots),
staticMap!(CompositeTypeDeps, staticMap!(ServiceTypeDeps, Services))
),
AddStructWithDeps
)
) Types;
enum result = ctfeJoin(
[
staticMap!(
enumIdlString,
StaticFilter!(isEnum, Types)
),
staticMap!(
structIdlString,
StaticFilter!(Any!(isStruct, isException), Types)
),
staticMap!(
serviceIdlString,
Services
)
],
"\n"
);
}
template ServiceTypeDeps(T) if (isService!T) {
alias staticMap!(
PApply!(MethodTypeDeps, T),
FilterMethodNames!(T, __traits(derivedMembers, T))
) ServiceTypeDeps;
}
template MethodTypeDeps(T, string name) if (
isService!T && isSomeFunction!(MemberType!(T, name))
) {
alias TypeTuple!(
ReturnType!(MemberType!(T, name)),
ParameterTypeTuple!(MemberType!(T, name)),
ExceptionTypes!(T, name)
) MethodTypeDeps;
}
template ExceptionTypes(T, string name) if (
isService!T && isSomeFunction!(MemberType!(T, name))
) {
mixin({
enum meta = find!`a.name == b`(getMethodMeta!T, name);
if (meta.empty) return "alias TypeTuple!() ExceptionTypes;";
string result = "alias TypeTuple!(";
foreach (i, e; meta.front.exceptions) {
if (i > 0) result ~= ", ";
result ~= "mixin(`T." ~ e.type ~ "`)";
}
result ~= ") ExceptionTypes;";
return result;
}());
}
template AddBaseServices(T, List...) {
static if (staticIndexOf!(T, List) == -1) {
alias NoDuplicates!(BaseServices!T, List) AddBaseServices;
} else {
alias List AddStructWithDeps;
}
}
unittest {
interface A {}
interface B : A {}
interface C : B {}
interface D : A {}
static assert(is(AddBaseServices!(C) == TypeTuple!(A, B, C)));
static assert(is(ForAllWithList!(ConfinedTuple!(C, D), AddBaseServices) ==
TypeTuple!(A, D, B, C)));
}
template BaseServices(T, Rest...) if (isService!T) {
static if (isDerivedService!T) {
alias BaseServices!(BaseService!T, T, Rest) BaseServices;
} else {
alias TypeTuple!(T, Rest) BaseServices;
}
}
template AddStructWithDeps(T, List...) {
static if (staticIndexOf!(T, List) == -1) {
// T is not already in the List, so add T and the types it depends on in
// the front. Because with the Thrift compiler types can only depend on
// other types that have already been defined, we collect all the
// dependencies, prepend them to the list, and then prune the duplicates
// (keeping the first occurrences). If this requirement should ever be
// dropped from Thrift, this could be easily adapted to handle circular
// dependencies by passing TypeTuple!(T, List) to ForAllWithList instead
// of appending List afterwards, and removing the now unnecessary
// NoDuplicates.
alias NoDuplicates!(
ForAllWithList!(
ConfinedTuple!(
staticMap!(
CompositeTypeDeps,
staticMap!(
PApply!(MemberType, T),
FieldNames!T
)
)
),
.AddStructWithDeps,
T
),
List
) AddStructWithDeps;
} else {
alias List AddStructWithDeps;
}
}
version (unittest) {
struct A {}
struct B {
A a;
int b;
A c;
string d;
}
struct C {
B b;
A a;
}
static assert(is(AddStructWithDeps!C == TypeTuple!(A, B, C)));
struct D {
C c;
mixin TStructHelpers!([TFieldMeta("c", 0, TReq.IGNORE)]);
}
static assert(is(AddStructWithDeps!D == TypeTuple!(D)));
}
version (unittest) {
// Circles in the type dependency graph are not allowed in Thrift, but make
// sure we fail in a sane way instead of crashing the compiler.
struct Rec1 {
Rec2[] other;
}
struct Rec2 {
Rec1[] other;
}
static assert(!__traits(compiles, AddStructWithDeps!Rec1));
}
/*
* Returns the non-primitive types T directly depends on.
*
* For example, CompositeTypeDeps!int would yield an empty type tuple,
* CompositeTypeDeps!SomeStruct would give SomeStruct, and
* CompositeTypeDeps!(A[B]) both CompositeTypeDeps!A and CompositeTypeDeps!B.
*/
template CompositeTypeDeps(T) {
static if (is(FullyUnqual!T == bool) || is(FullyUnqual!T == byte) ||
is(FullyUnqual!T == short) || is(FullyUnqual!T == int) ||
is(FullyUnqual!T == long) || is(FullyUnqual!T : string) ||
is(FullyUnqual!T == double) || is(FullyUnqual!T == void)
) {
alias TypeTuple!() CompositeTypeDeps;
} else static if (is(FullyUnqual!T _ : U[], U)) {
alias CompositeTypeDeps!U CompositeTypeDeps;
} else static if (is(FullyUnqual!T _ : HashSet!E, E)) {
alias CompositeTypeDeps!E CompositeTypeDeps;
} else static if (is(FullyUnqual!T _ : V[K], K, V)) {
alias TypeTuple!(CompositeTypeDeps!K, CompositeTypeDeps!V) CompositeTypeDeps;
} else static if (is(FullyUnqual!T == enum) || is(FullyUnqual!T == struct) ||
is(FullyUnqual!T : TException)
) {
alias TypeTuple!(FullyUnqual!T) CompositeTypeDeps;
} else {
static assert(false, "Cannot represent type in Thrift: " ~ T.stringof);
}
}
}
/**
* Returns an IDL string describing the passed service. IDL code for any type
* dependcies is not included.
*/
template serviceIdlString(T) if (isService!T) {
enum serviceIdlString = {
string result = "service " ~ T.stringof;
static if (isDerivedService!T) {
result ~= " extends " ~ BaseService!T.stringof;
}
result ~= " {\n";
foreach (methodName; FilterMethodNames!(T, __traits(derivedMembers, T))) {
result ~= " ";
enum meta = find!`a.name == b`(T.methodMeta, methodName);
static if (!meta.empty && meta.front.type == TMethodType.ONEWAY) {
result ~= "oneway ";
}
alias ReturnType!(MemberType!(T, methodName)) RT;
static if (is(RT == void)) {
// We special-case this here instead of adding void to dToIdlType to
// avoid accepting things like void[].
result ~= "void ";
} else {
result ~= dToIdlType!RT ~ " ";
}
result ~= methodName ~ "(";
short lastId;
foreach (i, ParamType; ParameterTypeTuple!(MemberType!(T, methodName))) {
static if (!meta.empty && i < meta.front.params.length) {
enum havePM = true;
} else {
enum havePM = false;
}
short id;
static if (havePM) {
id = meta.front.params[i].id;
} else {
id = --lastId;
}
string paramName;
static if (havePM) {
paramName = meta.front.params[i].name;
} else {
paramName = "param" ~ to!string(i + 1);
}
result ~= to!string(id) ~ ": " ~ dToIdlType!ParamType ~ " " ~ paramName;
static if (havePM && !meta.front.params[i].defaultValue.empty) {
result ~= " = " ~ dToIdlConst(mixin(meta.front.params[i].defaultValue));
} else {
// Unfortunately, getting the default value for parameters from a
// function alias isn't possible we can't transfer the default
// value to the IDL e.g. for interface Foo { void foo(int a = 5); }
// without the user explicitly declaring it in metadata.
}
result ~= ", ";
}
result ~= ")";
static if (!meta.empty && !meta.front.exceptions.empty) {
result ~= " throws (";
foreach (e; meta.front.exceptions) {
result ~= to!string(e.id) ~ ": " ~ e.type ~ " " ~ e.name ~ ", ";
}
result ~= ")";
}
result ~= ",\n";
}
result ~= "}\n";
return result;
}();
}
/**
* Returns an IDL string describing the passed enum. IDL code for any type
* dependcies is not included.
*/
template enumIdlString(T) if (isEnum!T) {
enum enumIdlString = {
static assert(is(OriginalType!T : long),
"Can only have integer enums in Thrift (not " ~ OriginalType!T.stringof ~
", for " ~ T.stringof ~ ").");
string result = "enum " ~ T.stringof ~ " {\n";
foreach (name; __traits(derivedMembers, T)) {
result ~= " " ~ name ~ " = " ~ dToIdlConst(GetMember!(T, name)) ~ ",\n";
}
result ~= "}\n";
return result;
}();
}
/**
* Returns an IDL string describing the passed struct. IDL code for any type
* dependcies is not included.
*/
template structIdlString(T) if (isStruct!T || isException!T) {
enum structIdlString = {
mixin({
string code = "";
foreach (field; getFieldMeta!T) {
code ~= "static assert(is(MemberType!(T, `" ~ field.name ~ "`)));\n";
}
return code;
}());
string result;
static if (isException!T) {
result = "exception ";
} else {
result = "struct ";
}
result ~= T.stringof ~ " {\n";
// The last automatically assigned id fields with no meta information
// are assigned (in lexical order) descending negative ids, starting with
// -1, just like the Thrift compiler does.
short lastId;
foreach (name; FieldNames!T) {
enum meta = find!`a.name == b`(getFieldMeta!T, name);
static if (meta.empty || meta.front.req != TReq.IGNORE) {
short id;
static if (meta.empty) {
id = --lastId;
} else {
id = meta.front.id;
}
result ~= " " ~ to!string(id) ~ ":";
static if (!meta.empty) {
result ~= dToIdlReq(meta.front.req);
}
result ~= " " ~ dToIdlType!(MemberType!(T, name)) ~ " " ~ name;
static if (!meta.empty && !meta.front.defaultValue.empty) {
result ~= " = " ~ dToIdlConst(mixin(meta.front.defaultValue));
} else static if (__traits(compiles, fieldInitA!(T, name))) {
static if (is(typeof(fieldInitA!(T, name))) &&
!is(typeof(fieldInitA!(T, name)) == void)
) {
result ~= " = " ~ dToIdlConst(fieldInitA!(T, name));
}
} else static if (is(typeof(fieldInitB!(T, name))) &&
!is(typeof(fieldInitB!(T, name)) == void)
) {
result ~= " = " ~ dToIdlConst(fieldInitB!(T, name));
}
result ~= ",\n";
}
}
result ~= "}\n";
return result;
}();
}
private {
// This very convoluted way of doing things was chosen because putting the
// static if directly into structIdlString caused »not evaluatable at compile
// time« errors to slip through even though typeof() was used, resp. the
// condition to be true even though the value couldn't actually be read at
// compile time due to a @@BUG@@ in DMD 2.055.
// The extra »compiled« field in fieldInitA is needed because we must not try
// to use != if !is compiled as well (but was false), e.g. for floating point
// types.
template fieldInitA(T, string name) {
static if (mixin("T.init." ~ name) !is MemberType!(T, name).init) {
enum fieldInitA = mixin("T.init." ~ name);
}
}
template fieldInitB(T, string name) {
static if (mixin("T.init." ~ name) != MemberType!(T, name).init) {
enum fieldInitB = mixin("T.init." ~ name);
}
}
template dToIdlType(T) {
static if (is(FullyUnqual!T == bool)) {
enum dToIdlType = "bool";
} else static if (is(FullyUnqual!T == byte)) {
enum dToIdlType = "byte";
} else static if (is(FullyUnqual!T == double)) {
enum dToIdlType = "double";
} else static if (is(FullyUnqual!T == short)) {
enum dToIdlType = "i16";
} else static if (is(FullyUnqual!T == int)) {
enum dToIdlType = "i32";
} else static if (is(FullyUnqual!T == long)) {
enum dToIdlType = "i64";
} else static if (is(FullyUnqual!T : string)) {
enum dToIdlType = "string";
} else static if (is(FullyUnqual!T _ : U[], U)) {
enum dToIdlType = "list<" ~ dToIdlType!U ~ ">";
} else static if (is(FullyUnqual!T _ : V[K], K, V)) {
enum dToIdlType = "map<" ~ dToIdlType!K ~ ", " ~ dToIdlType!V ~ ">";
} else static if (is(FullyUnqual!T _ : HashSet!E, E)) {
enum dToIdlType = "set<" ~ dToIdlType!E ~ ">";
} else static if (is(FullyUnqual!T == struct) || is(FullyUnqual!T == enum) ||
is(FullyUnqual!T : TException)
) {
enum dToIdlType = FullyUnqual!(T).stringof;
} else {
static assert(false, "Cannot represent type in Thrift: " ~ T.stringof);
}
}
string dToIdlReq(TReq req) {
switch (req) {
case TReq.REQUIRED: return " required";
case TReq.OPTIONAL: return " optional";
default: return "";
}
}
string dToIdlConst(T)(T value) {
static if (is(FullyUnqual!T == bool)) {
return value ? "1" : "0";
} else static if (is(FullyUnqual!T == byte) ||
is(FullyUnqual!T == short) || is(FullyUnqual!T == int) ||
is(FullyUnqual!T == long)
) {
return to!string(value);
} else static if (is(FullyUnqual!T : string)) {
return `"` ~ to!string(value) ~ `"`;
} else static if (is(FullyUnqual!T == double)) {
return ctfeToString(value);
} else static if (is(FullyUnqual!T _ : U[], U) ||
is(FullyUnqual!T _ : HashSet!E, E)
) {
string result = "[";
foreach (e; value) {
result ~= dToIdlConst(e) ~ ", ";
}
result ~= "]";
return result;
} else static if (is(FullyUnqual!T _ : V[K], K, V)) {
string result = "{";
foreach (key, val; value) {
result ~= dToIdlConst(key) ~ ": " ~ dToIdlConst(val) ~ ", ";
}
result ~= "}";
return result;
} else static if (is(FullyUnqual!T == enum)) {
import std.conv;
import std.traits;
return to!string(cast(OriginalType!T)value);
} else static if (is(FullyUnqual!T == struct) ||
is(FullyUnqual!T : TException)
) {
string result = "{";
foreach (name; __traits(derivedMembers, T)) {
static if (memberReq!(T, name) != TReq.IGNORE) {
result ~= name ~ ": " ~ dToIdlConst(mixin("value." ~ name)) ~ ", ";
}
}
result ~= "}";
return result;
} else {
static assert(false, "Cannot represent type in Thrift: " ~ T.stringof);
}
}
}
version (unittest) {
enum Foo {
a = 1,
b = 10,
c = 5
}
static assert(enumIdlString!Foo ==
`enum Foo {
a = 1,
b = 10,
c = 5,
}
`);
}
version (unittest) {
struct WithoutMeta {
string a;
int b;
}
struct WithDefaults {
string a = "asdf";
double b = 3.1415;
WithoutMeta c;
mixin TStructHelpers!([
TFieldMeta("c", 1, TReq.init, `WithoutMeta("foo", 3)`)
]);
}
// These are from DebugProtoTest.thrift.
struct OneOfEach {
bool im_true;
bool im_false;
byte a_bite;
short integer16;
int integer32;
long integer64;
double double_precision;
string some_characters;
string zomg_unicode;
bool what_who;
string base64;
byte[] byte_list;
short[] i16_list;
long[] i64_list;
mixin TStructHelpers!([
TFieldMeta(`im_true`, 1),
TFieldMeta(`im_false`, 2),
TFieldMeta(`a_bite`, 3, TReq.OPT_IN_REQ_OUT, q{cast(byte)127}),
TFieldMeta(`integer16`, 4, TReq.OPT_IN_REQ_OUT, q{cast(short)32767}),
TFieldMeta(`integer32`, 5),
TFieldMeta(`integer64`, 6, TReq.OPT_IN_REQ_OUT, q{10000000000L}),
TFieldMeta(`double_precision`, 7),
TFieldMeta(`some_characters`, 8),
TFieldMeta(`zomg_unicode`, 9),
TFieldMeta(`what_who`, 10),
TFieldMeta(`base64`, 11),
TFieldMeta(`byte_list`, 12, TReq.OPT_IN_REQ_OUT, q{{
byte[] v;
v ~= cast(byte)1;
v ~= cast(byte)2;
v ~= cast(byte)3;
return v;
}()}),
TFieldMeta(`i16_list`, 13, TReq.OPT_IN_REQ_OUT, q{{
short[] v;
v ~= cast(short)1;
v ~= cast(short)2;
v ~= cast(short)3;
return v;
}()}),
TFieldMeta(`i64_list`, 14, TReq.OPT_IN_REQ_OUT, q{{
long[] v;
v ~= 1L;
v ~= 2L;
v ~= 3L;
return v;
}()})
]);
}
struct Bonk {
int type;
string message;
mixin TStructHelpers!([
TFieldMeta(`type`, 1),
TFieldMeta(`message`, 2)
]);
}
struct HolyMoley {
OneOfEach[] big;
HashSet!(string[]) contain;
Bonk[][string] bonks;
mixin TStructHelpers!([
TFieldMeta(`big`, 1),
TFieldMeta(`contain`, 2),
TFieldMeta(`bonks`, 3)
]);
}
static assert(structIdlString!WithoutMeta ==
`struct WithoutMeta {
-1: string a,
-2: i32 b,
}
`);
import std.algorithm;
static assert(structIdlString!WithDefaults.startsWith(
`struct WithDefaults {
-1: string a = "asdf",
-2: double b = 3.141`));
static assert(structIdlString!WithDefaults.endsWith(
`1: WithoutMeta c = {a: "foo", b: 3, },
}
`));
static assert(structIdlString!OneOfEach ==
`struct OneOfEach {
1: bool im_true,
2: bool im_false,
3: byte a_bite = 127,
4: i16 integer16 = 32767,
5: i32 integer32,
6: i64 integer64 = 10000000000,
7: double double_precision,
8: string some_characters,
9: string zomg_unicode,
10: bool what_who,
11: string base64,
12: list<byte> byte_list = [1, 2, 3, ],
13: list<i16> i16_list = [1, 2, 3, ],
14: list<i64> i64_list = [1, 2, 3, ],
}
`);
static assert(structIdlString!Bonk ==
`struct Bonk {
1: i32 type,
2: string message,
}
`);
static assert(structIdlString!HolyMoley ==
`struct HolyMoley {
1: list<OneOfEach> big,
2: set<list<string>> contain,
3: map<string, list<Bonk>> bonks,
}
`);
}
version (unittest) {
class ExceptionWithAMap : TException {
string blah;
string[string] map_field;
mixin TStructHelpers!([
TFieldMeta(`blah`, 1),
TFieldMeta(`map_field`, 2)
]);
}
interface Srv {
void voidMethod();
int primitiveMethod();
OneOfEach structMethod();
void methodWithDefaultArgs(int something);
void onewayMethod();
void exceptionMethod();
alias .ExceptionWithAMap ExceptionWithAMap;
enum methodMeta = [
TMethodMeta(`methodWithDefaultArgs`,
[TParamMeta(`something`, 1, q{2})]
),
TMethodMeta(`onewayMethod`,
[],
[],
TMethodType.ONEWAY
),
TMethodMeta(`exceptionMethod`,
[],
[
TExceptionMeta("a", 1, "ExceptionWithAMap"),
TExceptionMeta("b", 2, "ExceptionWithAMap")
]
)
];
}
interface ChildSrv : Srv {
int childMethod(int arg);
}
static assert(idlString!ChildSrv ==
`exception ExceptionWithAMap {
1: string blah,
2: map<string, string> map_field,
}
struct OneOfEach {
1: bool im_true,
2: bool im_false,
3: byte a_bite = 127,
4: i16 integer16 = 32767,
5: i32 integer32,
6: i64 integer64 = 10000000000,
7: double double_precision,
8: string some_characters,
9: string zomg_unicode,
10: bool what_who,
11: string base64,
12: list<byte> byte_list = [1, 2, 3, ],
13: list<i16> i16_list = [1, 2, 3, ],
14: list<i64> i64_list = [1, 2, 3, ],
}
service Srv {
void voidMethod(),
i32 primitiveMethod(),
OneOfEach structMethod(),
void methodWithDefaultArgs(1: i32 something = 2, ),
oneway void onewayMethod(),
void exceptionMethod() throws (1: ExceptionWithAMap a, 2: ExceptionWithAMap b, ),
}
service ChildSrv extends Srv {
i32 childMethod(-1: i32 param1, ),
}
`);
}

View file

@ -0,0 +1,497 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.processor;
import std.algorithm : find;
import std.array : empty, front;
import std.conv : to;
import std.traits : ParameterTypeTuple, ReturnType, Unqual;
import std.typetuple : allSatisfy, TypeTuple;
import std.variant : Variant;
import thrift.base;
import thrift.codegen.base;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.protocol.base;
import thrift.protocol.processor;
/**
* Service processor for Interface, which implements TProcessor by
* synchronously forwarding requests for the service methods to a handler
* implementing Interface.
*
* The generated class implements TProcessor and additionally allows a
* TProcessorEventHandler to be specified via the public eventHandler property.
* The constructor takes a single argument of type Interface, which is the
* handler to forward the requests to:
* ---
* this(Interface iface);
* TProcessorEventHandler eventHandler;
* ---
*
* If Interface is derived from another service BaseInterface, this class is
* also derived from TServiceProcessor!BaseInterface.
*
* The optional Protocols template tuple parameter can be used to specify
* one or more TProtocol implementations to specifically generate code for. If
* the actual types of the protocols passed to process() at runtime match one
* of the items from the list, the optimized code paths are taken, otherwise,
* a generic TProtocol version is used as fallback. For cases where the input
* and output protocols differ, TProtocolPair!(InputProtocol, OutputProtocol)
* can be used in the Protocols list:
* ---
* interface FooService { void foo(); }
* class FooImpl { override void foo {} }
*
* // Provides fast path if TBinaryProtocol!TBufferedTransport is used for
* // both input and output:
* alias TServiceProcessor!(FooService, TBinaryProtocol!TBufferedTransport)
* BinaryProcessor;
*
* auto proc = new BinaryProcessor(new FooImpl());
*
* // Low overhead.
* proc.process(tBinaryProtocol(tBufferTransport(someSocket)));
*
* // Not in the specialization list higher overhead.
* proc.process(tBinaryProtocol(tFramedTransport(someSocket)));
*
* // Same as above, but optimized for the Compact protocol backed by a
* // TPipedTransport for input and a TBufferedTransport for output.
* alias TServiceProcessor!(FooService, TProtocolPair!(
* TCompactProtocol!TPipedTransport, TCompactProtocol!TBufferedTransport)
* ) MixedProcessor;
* ---
*/
template TServiceProcessor(Interface, Protocols...) if (
isService!Interface && allSatisfy!(isTProtocolOrPair, Protocols)
) {
mixin({
static if (is(Interface BaseInterfaces == super) && BaseInterfaces.length > 0) {
static assert(BaseInterfaces.length == 1,
"Services cannot be derived from more than one parent.");
string code = "class TServiceProcessor : " ~
"TServiceProcessor!(BaseService!Interface) {\n";
code ~= "private Interface iface_;\n";
string constructorCode = "this(Interface iface) {\n";
constructorCode ~= "super(iface);\n";
constructorCode ~= "iface_ = iface;\n";
} else {
string code = "class TServiceProcessor : TProcessor {";
code ~= q{
override bool process(TProtocol iprot, TProtocol oprot,
Variant context = Variant()
) {
auto msg = iprot.readMessageBegin();
void writeException(TApplicationException e) {
oprot.writeMessageBegin(TMessage(msg.name, TMessageType.EXCEPTION,
msg.seqid));
e.write(oprot);
oprot.writeMessageEnd();
oprot.transport.writeEnd();
oprot.transport.flush();
}
if (msg.type != TMessageType.CALL && msg.type != TMessageType.ONEWAY) {
skip(iprot, TType.STRUCT);
iprot.readMessageEnd();
iprot.transport.readEnd();
writeException(new TApplicationException(
TApplicationException.Type.INVALID_MESSAGE_TYPE));
return false;
}
auto dg = msg.name in processMap_;
if (!dg) {
skip(iprot, TType.STRUCT);
iprot.readMessageEnd();
iprot.transport.readEnd();
writeException(new TApplicationException("Invalid method name: '" ~
msg.name ~ "'.", TApplicationException.Type.INVALID_MESSAGE_TYPE));
return false;
}
(*dg)(msg.seqid, iprot, oprot, context);
return true;
}
TProcessorEventHandler eventHandler;
alias void delegate(int, TProtocol, TProtocol, Variant) ProcessFunc;
protected ProcessFunc[string] processMap_;
private Interface iface_;
};
string constructorCode = "this(Interface iface) {\n";
constructorCode ~= "iface_ = iface;\n";
}
// Generate the handling code for each method, consisting of the dispatch
// function, registering it in the constructor, and the actual templated
// handler function.
foreach (methodName;
FilterMethodNames!(Interface, __traits(derivedMembers, Interface))
) {
// Register the processing function in the constructor.
immutable procFuncName = "process_" ~ methodName;
immutable dispatchFuncName = procFuncName ~ "_protocolDispatch";
constructorCode ~= "processMap_[`" ~ methodName ~ "`] = &" ~
dispatchFuncName ~ ";\n";
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
enum meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
// The dispatch function to call the specialized handler functions. We
// test the protocols if they can be converted to one of the passed
// protocol types, and if not, fall back to the generic TProtocol
// version of the processing function.
code ~= "void " ~ dispatchFuncName ~
"(int seqid, TProtocol iprot, TProtocol oprot, Variant context) {\n";
code ~= "foreach (Protocol; TypeTuple!(Protocols, TProtocol)) {\n";
code ~= q{
static if (is(Protocol _ : TProtocolPair!(I, O), I, O)) {
alias I IProt;
alias O OProt;
} else {
alias Protocol IProt;
alias Protocol OProt;
}
auto castedIProt = cast(IProt)iprot;
auto castedOProt = cast(OProt)oprot;
};
code ~= "if (castedIProt && castedOProt) {\n";
code ~= procFuncName ~
"!(IProt, OProt)(seqid, castedIProt, castedOProt, context);\n";
code ~= "return;\n";
code ~= "}\n";
code ~= "}\n";
code ~= "throw new TException(`Internal error: Null iprot/oprot " ~
"passed to processor protocol dispatch function.`);\n";
code ~= "}\n";
// The actual handler function, templated on the input and output
// protocol types.
code ~= "void " ~ procFuncName ~ "(IProt, OProt)(int seqid, IProt " ~
"iprot, OProt oprot, Variant connectionContext) " ~
"if (isTProtocol!IProt && isTProtocol!OProt) {\n";
code ~= "TArgsStruct!(Interface, `" ~ methodName ~ "`) args;\n";
// Store the (qualified) method name in a manifest constant to avoid
// having to litter the code below with lots of string manipulation.
code ~= "enum methodName = `" ~ methodName ~ "`;\n";
code ~= q{
enum qName = Interface.stringof ~ "." ~ methodName;
Variant callContext;
if (eventHandler) {
callContext = eventHandler.createContext(qName, connectionContext);
}
scope (exit) {
if (eventHandler) {
eventHandler.deleteContext(callContext, qName);
}
}
if (eventHandler) eventHandler.preRead(callContext, qName);
args.read(iprot);
iprot.readMessageEnd();
iprot.transport.readEnd();
if (eventHandler) eventHandler.postRead(callContext, qName);
};
code ~= "TResultStruct!(Interface, `" ~ methodName ~ "`) result;\n";
code ~= "try {\n";
// Generate the parameter list to pass to the called iface function.
string[] paramList;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
string paramName;
if (methodMetaFound && i < methodMeta.params.length) {
paramName = methodMeta.params[i].name;
} else {
paramName = "param" ~ to!string(i + 1);
}
paramList ~= "args." ~ paramName;
}
immutable call = "iface_." ~ methodName ~ "(" ~ ctfeJoin(paramList) ~ ")";
if (is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= call ~ ";\n";
} else {
code ~= "result.set!`success`(" ~ call ~ ");\n";
}
// If this is not a oneway method, generate the receiving code.
if (!methodMetaFound || methodMeta.type != TMethodType.ONEWAY) {
if (methodMetaFound) {
foreach (e; methodMeta.exceptions) {
code ~= "} catch (Interface." ~ e.type ~ " " ~ e.name ~ ") {\n";
code ~= "result.set!`" ~ e.name ~ "`(" ~ e.name ~ ");\n";
}
}
code ~= "}\n";
code ~= q{
catch (Exception e) {
if (eventHandler) {
eventHandler.handlerError(callContext, qName, e);
}
auto x = new TApplicationException(to!string(e));
oprot.writeMessageBegin(
TMessage(methodName, TMessageType.EXCEPTION, seqid));
x.write(oprot);
oprot.writeMessageEnd();
oprot.transport.writeEnd();
oprot.transport.flush();
return;
}
if (eventHandler) eventHandler.preWrite(callContext, qName);
oprot.writeMessageBegin(TMessage(methodName,
TMessageType.REPLY, seqid));
result.write(oprot);
oprot.writeMessageEnd();
oprot.transport.writeEnd();
oprot.transport.flush();
if (eventHandler) eventHandler.postWrite(callContext, qName);
};
} else {
// For oneway methods, we obviously cannot notify the client of any
// exceptions, just call the event handler if one is set.
code ~= "}\n";
code ~= q{
catch (Exception e) {
if (eventHandler) {
eventHandler.handlerError(callContext, qName, e);
}
return;
}
if (eventHandler) eventHandler.onewayComplete(callContext, qName);
};
}
code ~= "}\n";
}
code ~= constructorCode ~ "}\n";
code ~= "}\n";
return code;
}());
}
/**
* A struct representing the arguments of a Thrift method call.
*
* There should usually be no reason to use this directly without the help of
* TServiceProcessor, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider this example:
* ---
* interface Foo {
* int bar(string a, bool b);
*
* enum methodMeta = [
* TMethodMeta("bar", [TParamMeta("a", 1), TParamMeta("b", 2)])
* ];
* }
*
* alias TArgsStruct!(Foo, "bar") FooBarArgs;
* ---
*
* The definition of FooBarArgs is equivalent to:
* ---
* struct FooBarArgs {
* string a;
* bool b;
*
* mixin TStructHelpers!([TFieldMeta("a", 1, TReq.OPT_IN_REQ_OUT),
* TFieldMeta("b", 2, TReq.OPT_IN_REQ_OUT)]);
* }
* ---
*
* If the TVerboseCodegen version is defined, a warning message is issued at
* compilation if no TMethodMeta for Interface.methodName is found.
*/
template TArgsStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
string memberCode;
string[] fieldMetaCodes;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
// If we have no meta information, just use param1, param2, etc. as
// field names, it shouldn't really matter anyway. 1-based »indexing«
// is used to match the common scheme in the Thrift world.
string memberId;
string memberName;
if (methodMetaFound && i < methodMeta.params.length) {
memberId = to!string(methodMeta.params[i].id);
memberName = methodMeta.params[i].name;
} else {
memberId = to!string(i + 1);
memberName = "param" ~ to!string(i + 1);
}
// Unqual!() is needed to generate mutable fields for ref const()
// struct parameters.
memberCode ~= "Unqual!(ParameterTypeTuple!(Interface." ~ methodName ~
")[" ~ to!string(i) ~ "])" ~ memberName ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ memberName ~ "`, " ~ memberId ~
", TReq.OPT_IN_REQ_OUT)";
}
string code = "struct TArgsStruct {\n";
code ~= memberCode;
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.processor.TArgsStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
immutable fieldMetaCode =
fieldMetaCodes.empty ? "" : "[" ~ ctfeJoin(fieldMetaCodes) ~ "]";
code ~= "mixin TStructHelpers!(" ~ fieldMetaCode ~ ");\n";
code ~= "}\n";
return code;
}());
}
/**
* A struct representing the result of a Thrift method call.
*
* It contains a field called "success" for the return value of the function
* (with id 0), and additional fields for the exceptions declared for the
* method, if any.
*
* There should usually be no reason to use this directly without the help of
* TServiceProcessor, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider the following example:
* ---
* interface Foo {
* int bar(string a);
*
* alias .FooException FooException;
*
* enum methodMeta = [
* TMethodMeta("bar",
* [TParamMeta("a", 1)],
* [TExceptionMeta("fooe", 1, "FooException")]
* )
* ];
* }
* alias TResultStruct!(Foo, "bar") FooBarResult;
* ---
*
* The definition of FooBarResult is equivalent to:
* ---
* struct FooBarResult {
* int success;
* FooException fooe;
*
* mixin(TStructHelpers!([TFieldMeta("success", 0, TReq.OPTIONAL),
* TFieldMeta("fooe", 1, TReq.OPTIONAL)]));
* }
* ---
*
* If the TVerboseCodegen version is defined, a warning message is issued at
* compilation if no TMethodMeta for Interface.methodName is found.
*/
template TResultStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
string code = "struct TResultStruct {\n";
string[] fieldMetaCodes;
static if (!is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= "ReturnType!(Interface." ~ methodName ~ ") success;\n";
fieldMetaCodes ~= "TFieldMeta(`success`, 0, TReq.OPTIONAL)";
}
bool methodMetaFound;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
foreach (e; meta.front.exceptions) {
code ~= "Interface." ~ e.type ~ " " ~ e.name ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ e.name ~ "`, " ~ to!string(e.id) ~
", TReq.OPTIONAL)";
}
methodMetaFound = true;
}
}
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.processor.TResultStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
immutable fieldMetaCode =
fieldMetaCodes.empty ? "" : "[" ~ ctfeJoin(fieldMetaCodes) ~ "]";
code ~= "mixin TStructHelpers!(" ~ fieldMetaCode ~ ");\n";
code ~= "}\n";
return code;
}());
}

View file

@ -0,0 +1,33 @@
Ddoc
<h2>Package overview</h2>
<dl>
<dt>$(D_CODE thrift.async)</dt>
<dd>Support infrastructure for handling client-side asynchronous operations using non-blocking I/O and coroutines.</dd>
<dt>$(D_CODE thrift.codegen)</dt>
<dd>
<p>Templates used for generating Thrift clients/processors from regular D struct and interface definitions.</p>
<p><strong>Note:</strong> Several artifacts in these modules have options for specifying the exact protocol types used. In this case, the amount of virtual calls can be greatly reduced and as a result, the code also can be optimized better. If performance is not a concern or the actual protocol type is not known at compile time, these parameters can just be left at their defaults.
</p>
</dd>
<dt>$(D_CODE thrift.internal)</dt>
<dd>Internal helper modules used by the Thrift library. This package is not part of the public API, and no stability guarantees are given whatsoever.</dd>
<dt>$(D_CODE thrift.protocol)</dt>
<dd>The Thrift protocol implemtations which specify how to pass messages over a TTransport.</dd>
<dt>$(D_CODE thrift.server)</dt>
<dd>Generic Thrift server implementations handling clients over a TTransport interface and forwarding requests to a TProcessor (which is in turn usually provided by thrift.codegen).</dd>
<dt>$(D_CODE thrift.transport)</dt>
<dd>The TTransport data source/sink interface used in the Thrift library and its imiplementations.</dd>
<dt>$(D_CODE thrift.util)</dt>
<dd>General-purpose utility modules not specific to Thrift, part of the public API.</dd>
</dl>
Macros:
TITLE = Thrift D Software Library

View file

@ -0,0 +1,55 @@
/**
* Contains a modified version of std.algorithm.remove that doesn't take an
* alias parameter to avoid DMD @@BUG6395@@.
*/
module thrift.internal.algorithm;
import std.algorithm : move;
import std.exception;
import std.functional;
import std.range;
import std.traits;
enum SwapStrategy
{
unstable,
semistable,
stable,
}
Range removeEqual(SwapStrategy s = SwapStrategy.stable, Range, E)(Range range, E e)
if (isBidirectionalRange!Range)
{
auto result = range;
static if (s != SwapStrategy.stable)
{
for (;!range.empty;)
{
if (range.front !is e)
{
range.popFront;
continue;
}
move(range.back, range.front);
range.popBack;
result.popBack;
}
}
else
{
auto tgt = range;
for (; !range.empty; range.popFront)
{
if (range.front is e)
{
// yank this guy
result.popBack;
continue;
}
// keep this guy
move(range.front, tgt.front);
tgt.popFront;
}
}
return result;
}

View file

@ -0,0 +1,451 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.codegen;
import std.algorithm : canFind;
import std.traits : InterfacesTuple, isSomeFunction, isSomeString;
import std.typetuple : staticIndexOf, staticMap, NoDuplicates, TypeTuple;
import thrift.codegen.base;
/**
* Removes all type qualifiers from T.
*
* In contrast to std.traits.Unqual, FullyUnqual also removes qualifiers from
* array elements (e.g. immutable(byte[]) -> byte[], not immutable(byte)[]),
* excluding strings (string isn't reduced to char[]).
*/
template FullyUnqual(T) {
static if (is(T _ == const(U), U)) {
alias FullyUnqual!U FullyUnqual;
} else static if (is(T _ == immutable(U), U)) {
alias FullyUnqual!U FullyUnqual;
} else static if (is(T _ == shared(U), U)) {
alias FullyUnqual!U FullyUnqual;
} else static if (is(T _ == U[], U) && !isSomeString!T) {
alias FullyUnqual!(U)[] FullyUnqual;
} else static if (is(T _ == V[K], K, V)) {
alias FullyUnqual!(V)[FullyUnqual!K] FullyUnqual;
} else {
alias T FullyUnqual;
}
}
/**
* true if null can be assigned to the passed type, false if not.
*/
template isNullable(T) {
enum isNullable = __traits(compiles, { T t = null; });
}
template isStruct(T) {
enum isStruct = is(T == struct);
}
template isException(T) {
enum isException = is(T : Exception);
}
template isEnum(T) {
enum isEnum = is(T == enum);
}
/**
* Aliases itself to T.name.
*/
template GetMember(T, string name) {
mixin("alias T." ~ name ~ " GetMember;");
}
/**
* Aliases itself to typeof(symbol).
*/
template TypeOf(alias symbol) {
alias typeof(symbol) TypeOf;
}
/**
* Aliases itself to the type of the T member called name.
*/
alias Compose!(TypeOf, GetMember) MemberType;
/**
* Returns the field metadata array for T if any, or an empty array otherwise.
*/
template getFieldMeta(T) if (isStruct!T || isException!T) {
static if (is(typeof(T.fieldMeta) == TFieldMeta[])) {
enum getFieldMeta = T.fieldMeta;
} else {
enum TFieldMeta[] getFieldMeta = [];
}
}
/**
* Merges the field metadata array for D with the passed array.
*/
template mergeFieldMeta(T, alias fieldMetaData = cast(TFieldMeta[])null) {
// Note: We don't use getFieldMeta here to avoid bug if it is instantiated
// from TIsSetFlags, see comment there.
static if (is(typeof(T.fieldMeta) == TFieldMeta[])) {
enum mergeFieldMeta = T.fieldMeta ~ fieldMetaData;
} else {
enum TFieldMeta[] mergeFieldMeta = fieldMetaData;
}
}
/**
* Returns the field requirement level for T.name.
*/
template memberReq(T, string name, alias fieldMetaData = cast(TFieldMeta[])null) {
enum memberReq = memberReqImpl!(T, name, fieldMetaData).result;
}
private {
import std.algorithm : find;
// DMD @@BUG@@: Missing import leads to failing build without error
// message in unittest/debug/thrift/codegen/async_client.
import std.array : empty, front;
template memberReqImpl(T, string name, alias fieldMetaData) {
enum meta = find!`a.name == b`(mergeFieldMeta!(T, fieldMetaData), name);
static if (meta.empty || meta.front.req == TReq.AUTO) {
static if (isNullable!(MemberType!(T, name))) {
enum result = TReq.OPTIONAL;
} else {
enum result = TReq.REQUIRED;
}
} else {
enum result = meta.front.req;
}
}
}
template notIgnored(T, string name, alias fieldMetaData = cast(TFieldMeta[])null) {
enum notIgnored = memberReq!(T, name, fieldMetaData) != TReq.IGNORE;
}
/**
* Returns the method metadata array for T if any, or an empty array otherwise.
*/
template getMethodMeta(T) if (isService!T) {
static if (is(typeof(T.methodMeta) == TMethodMeta[])) {
enum getMethodMeta = T.methodMeta;
} else {
enum TMethodMeta[] getMethodMeta = [];
}
}
/**
* true if T.name is a member variable. Exceptions include methods, static
* members, artifacts like package aliases,
*/
template isValueMember(T, string name) {
static if (!is(MemberType!(T, name))) {
enum isValueMember = false;
} else static if (
is(MemberType!(T, name) == void) ||
isSomeFunction!(MemberType!(T, name)) ||
__traits(compiles, { return mixin("T." ~ name); }())
) {
enum isValueMember = false;
} else {
enum isValueMember = true;
}
}
/**
* Returns a tuple containing the names of the fields of T, not including
* inherited fields. If a member is marked as TReq.IGNORE, it is not included
* as well.
*/
template FieldNames(T, alias fieldMetaData = cast(TFieldMeta[])null) {
alias StaticFilter!(
All!(
doesNotReadMembers,
PApply!(isValueMember, T),
PApply!(notIgnored, T, PApplySkip, fieldMetaData)
),
__traits(derivedMembers, T)
) FieldNames;
}
/*
* true if the passed member name is not a method generated by the
* TStructHelpers template that in its implementations queries the struct
* members.
*
* Kludge used internally to break a cycle caused a DMD forward reference
* regression, see THRIFT-2130.
*/
enum doesNotReadMembers(string name) = !["opEquals", "thriftOpEqualsImpl",
"toString", "thriftToStringImpl"].canFind(name);
template derivedMembers(T) {
alias TypeTuple!(__traits(derivedMembers, T)) derivedMembers;
}
template AllMemberMethodNames(T) if (isService!T) {
alias NoDuplicates!(
FilterMethodNames!(
T,
staticMap!(
derivedMembers,
TypeTuple!(T, InterfacesTuple!T)
)
)
) AllMemberMethodNames;
}
template FilterMethodNames(T, MemberNames...) {
alias StaticFilter!(
CompilesAndTrue!(
Compose!(isSomeFunction, TypeOf, PApply!(GetMember, T))
),
MemberNames
) FilterMethodNames;
}
/**
* Returns a type tuple containing only the elements of T for which the
* eponymous template predicate pred is true.
*
* Example:
* ---
* alias StaticFilter!(isIntegral, int, string, long, float[]) Filtered;
* static assert(is(Filtered == TypeTuple!(int, long)));
* ---
*/
template StaticFilter(alias pred, T...) {
static if (T.length == 0) {
alias TypeTuple!() StaticFilter;
} else static if (pred!(T[0])) {
alias TypeTuple!(T[0], StaticFilter!(pred, T[1 .. $])) StaticFilter;
} else {
alias StaticFilter!(pred, T[1 .. $]) StaticFilter;
}
}
/**
* Binds the first n arguments of a template to a particular value (where n is
* the number of arguments passed to PApply).
*
* The passed arguments are always applied starting from the left. However,
* the special PApplySkip marker template can be used to indicate that an
* argument should be skipped, so that e.g. the first and third argument
* to a template can be fixed, but the second and remaining arguments would
* still be left undefined.
*
* Skipping a number of parameters, but not providing enough arguments to
* assign all of them during instantiation of the resulting template is an
* error.
*
* Example:
* ---
* struct Foo(T, U, V) {}
* alias PApply!(Foo, int, long) PartialFoo;
* static assert(is(PartialFoo!float == Foo!(int, long, float)));
*
* alias PApply!(Test, int, PApplySkip, float) SkippedTest;
* static assert(is(SkippedTest!long == Test!(int, long, float)));
* ---
*/
template PApply(alias Target, T...) {
template PApply(U...) {
alias Target!(PApplyMergeArgs!(ConfinedTuple!T, U).Result) PApply;
}
}
/// Ditto.
template PApplySkip() {}
private template PApplyMergeArgs(alias Preset, Args...) {
static if (Preset.length == 0) {
alias Args Result;
} else {
enum nextSkip = staticIndexOf!(PApplySkip, Preset.Tuple);
static if (nextSkip == -1) {
alias TypeTuple!(Preset.Tuple, Args) Result;
} else static if (Args.length == 0) {
// Have to use a static if clause instead of putting the condition
// directly into the assert to avoid DMD trying to access Args[0]
// nevertheless below.
static assert(false,
"PArgsSkip encountered, but no argument left to bind.");
} else {
alias TypeTuple!(
Preset.Tuple[0 .. nextSkip],
Args[0],
PApplyMergeArgs!(
ConfinedTuple!(Preset.Tuple[nextSkip + 1 .. $]),
Args[1 .. $]
).Result
) Result;
}
}
}
unittest {
struct Test(T, U, V) {}
alias PApply!(Test, int, long) PartialTest;
static assert(is(PartialTest!float == Test!(int, long, float)));
alias PApply!(Test, int, PApplySkip, float) SkippedTest;
static assert(is(SkippedTest!long == Test!(int, long, float)));
alias PApply!(Test, int, PApplySkip, PApplySkip) TwoSkipped;
static assert(!__traits(compiles, TwoSkipped!long));
}
/**
* Composes a number of templates. The result is a template equivalent to
* all the passed templates evaluated from right to left, akin to the
* mathematical function composition notation: Instantiating Compose!(A, B, C)
* is the same as instantiating A!(B!(C!())).
*
* This is especially useful for creating a template to use with staticMap/
* StaticFilter, as demonstrated below.
*
* Example:
* ---
* template AllMethodNames(T) {
* alias StaticFilter!(
* CompilesAndTrue!(
* Compose!(isSomeFunction, TypeOf, PApply!(GetMember, T))
* ),
* __traits(allMembers, T)
* ) AllMethodNames;
* }
*
* pragma(msg, AllMethodNames!Object);
* ---
*/
template Compose(T...) {
static if (T.length == 0) {
template Compose(U...) {
alias U Compose;
}
} else {
template Compose(U...) {
alias Instantiate!(T[0], Instantiate!(.Compose!(T[1 .. $]), U)) Compose;
}
}
}
/**
* Instantiates the given template with the given list of parameters.
*
* Used to work around syntactic limiations of D with regard to instantiating
* a template from a type tuple (e.g. T[0]!(...) is not valid) or a template
* returning another template (e.g. Foo!(Bar)!(Baz) is not allowed).
*/
template Instantiate(alias Template, Params...) {
alias Template!Params Instantiate;
}
/**
* Combines several template predicates using logical AND, i.e. instantiating
* All!(a, b, c) with parameters P for some templates a, b, c is equivalent to
* a!P && b!P && c!P.
*
* The templates are evaluated from left to right, aborting evaluation in a
* shurt-cut manner if a false result is encountered, in which case the latter
* instantiations do not need to compile.
*/
template All(T...) {
static if (T.length == 0) {
template All(U...) {
enum All = true;
}
} else {
template All(U...) {
static if (Instantiate!(T[0], U)) {
alias Instantiate!(.All!(T[1 .. $]), U) All;
} else {
enum All = false;
}
}
}
}
/**
* Combines several template predicates using logical OR, i.e. instantiating
* Any!(a, b, c) with parameters P for some templates a, b, c is equivalent to
* a!P || b!P || c!P.
*
* The templates are evaluated from left to right, aborting evaluation in a
* shurt-cut manner if a true result is encountered, in which case the latter
* instantiations do not need to compile.
*/
template Any(T...) {
static if (T.length == 0) {
template Any(U...) {
enum Any = false;
}
} else {
template Any(U...) {
static if (Instantiate!(T[0], U)) {
enum Any = true;
} else {
alias Instantiate!(.Any!(T[1 .. $]), U) Any;
}
}
}
}
template ConfinedTuple(T...) {
alias T Tuple;
enum length = T.length;
}
/*
* foreach (Item; Items) {
* List = Operator!(Item, List);
* }
* where Items is a ConfinedTuple and List is a type tuple.
*/
template ForAllWithList(alias Items, alias Operator, List...) if (
is(typeof(Items.length) : size_t)
){
static if (Items.length == 0) {
alias List ForAllWithList;
} else {
alias .ForAllWithList!(
ConfinedTuple!(Items.Tuple[1 .. $]),
Operator,
Operator!(Items.Tuple[0], List)
) ForAllWithList;
}
}
/**
* Wraps the passed template predicate so it returns true if it compiles and
* evaluates to true, false it it doesn't compile or evaluates to false.
*/
template CompilesAndTrue(alias T) {
template CompilesAndTrue(U...) {
static if (is(typeof(T!U) : bool)) {
enum bool CompilesAndTrue = T!U;
} else {
enum bool CompilesAndTrue = false;
}
}
}

View file

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.ctfe;
import std.conv : to;
import std.traits;
/*
* Simple eager join() for strings, std.algorithm.join isn't CTFEable yet.
*/
string ctfeJoin(string[] strings, string separator = ", ") {
string result;
if (strings.length > 0) {
result ~= strings[0];
foreach (s; strings[1..$]) {
result ~= separator ~ s;
}
}
return result;
}
/*
* A very primitive to!string() implementation for floating point numbers that
* is evaluatable at compile time.
*
* There is a wealth of problems associated with the algorithm used (e.g. 5.0
* prints as 4.999, incorrect rounding, etc.), but a better alternative should
* be included with the D standard library instead of implementing it here.
*/
string ctfeToString(T)(T val) if (isFloatingPoint!T) {
if (val is T.nan) return "nan";
if (val is T.infinity) return "inf";
if (val is -T.infinity) return "-inf";
if (val is 0.0) return "0";
if (val is -0.0) return "-0";
auto b = val;
string result;
if (b < 0) {
result ~= '-';
b *= -1;
}
short magnitude;
while (b >= 10) {
++magnitude;
b /= 10;
}
while (b < 1) {
--magnitude;
b *= 10;
}
foreach (i; 0 .. T.dig) {
if (i == 1) result ~= '.';
auto first = cast(ubyte)b;
result ~= to!string(first);
b -= first;
import std.math;
if (b < pow(10.0, i - T.dig)) break;
b *= 10;
}
if (magnitude != 0) result ~= "e" ~ to!string(magnitude);
return result;
}
unittest {
import std.algorithm;
static assert(ctfeToString(double.infinity) == "inf");
static assert(ctfeToString(-double.infinity) == "-inf");
static assert(ctfeToString(double.nan) == "nan");
static assert(ctfeToString(0.0) == "0");
static assert(ctfeToString(-0.0) == "-0");
static assert(ctfeToString(2.5) == "2.5");
static assert(ctfeToString(3.1415).startsWith("3.141"));
static assert(ctfeToString(2e-200) == "2e-200");
}

View file

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Simple helpers for handling typical byte order-related issues.
*/
module thrift.internal.endian;
import core.bitop : bswap;
import std.traits : isIntegral;
union IntBuf(T) {
ubyte[T.sizeof] bytes;
T value;
}
T byteSwap(T)(T t) pure nothrow @trusted if (isIntegral!T) {
static if (T.sizeof == 2) {
return cast(T)((t & 0xff) << 8) | cast(T)((t & 0xff00) >> 8);
} else static if (T.sizeof == 4) {
return cast(T)bswap(cast(uint)t);
} else static if (T.sizeof == 8) {
return cast(T)byteSwap(cast(uint)(t & 0xffffffff)) << 32 |
cast(T)bswap(cast(uint)(t >> 32));
} else static assert(false, "Type of size " ~ to!string(T.sizeof) ~ " not supported.");
}
T doNothing(T)(T val) { return val; }
version (BigEndian) {
alias doNothing hostToNet;
alias doNothing netToHost;
alias byteSwap hostToLe;
alias byteSwap leToHost;
} else {
alias byteSwap hostToNet;
alias byteSwap netToHost;
alias doNothing hostToLe;
alias doNothing leToHost;
}
unittest {
import std.exception;
IntBuf!short s;
s.bytes = [1, 2];
s.value = byteSwap(s.value);
enforce(s.bytes == [2, 1]);
IntBuf!int i;
i.bytes = [1, 2, 3, 4];
i.value = byteSwap(i.value);
enforce(i.bytes == [4, 3, 2, 1]);
IntBuf!long l;
l.bytes = [1, 2, 3, 4, 5, 6, 7, 8];
l.value = byteSwap(l.value);
enforce(l.bytes == [8, 7, 6, 5, 4, 3, 2, 1]);
}

View file

@ -0,0 +1,431 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.resource_pool;
import core.time : Duration, dur, TickDuration;
import std.algorithm : minPos, reduce, remove;
import std.array : array, empty;
import std.exception : enforce;
import std.conv : to;
import std.random : randomCover, rndGen;
import std.range : zip;
import thrift.internal.algorithm : removeEqual;
/**
* A pool of resources, which can be iterated over, and where resources that
* have failed too often can be temporarily disabled.
*
* This class is oblivious to the actual resource type managed.
*/
final class TResourcePool(Resource) {
/**
* Constructs a new instance.
*
* Params:
* resources = The initial members of the pool.
*/
this(Resource[] resources) {
resources_ = resources;
}
/**
* Adds a resource to the pool.
*/
void add(Resource resource) {
resources_ ~= resource;
}
/**
* Removes a resource from the pool.
*
* Returns: Whether the resource could be found in the pool.
*/
bool remove(Resource resource) {
auto oldLength = resources_.length;
resources_ = removeEqual(resources_, resource);
return resources_.length < oldLength;
}
/**
* Returns an »enriched« input range to iterate over the pool members.
*/
static struct Range {
/**
* Whether the range is empty.
*
* This is the case if all members of the pool have been popped (or skipped
* because they were disabled) and TResourcePool.cycle is false, or there
* is no element to return in cycle mode because all have been temporarily
* disabled.
*/
bool empty() @property {
// If no resources are in the pool, the range will never become non-empty.
if (resources_.empty) return true;
// If we already got the next resource in the cache, it doesn't matter
// whether there are more.
if (cached_) return false;
size_t examineCount;
if (parent_.cycle) {
// We want to check all the resources, but not iterate more than once
// to avoid spinning in a loop if nothing is available.
examineCount = resources_.length;
} else {
// When not in cycle mode, we just iterate the list exactly once. If all
// items have been consumed, the interval below is empty.
examineCount = resources_.length - nextIndex_;
}
foreach (i; 0 .. examineCount) {
auto r = resources_[(nextIndex_ + i) % resources_.length];
auto fi = r in parent_.faultInfos_;
if (fi && fi.resetTime != fi.resetTime.init) {
if (fi.resetTime < parent_.getCurrentTick_()) {
// The timeout expired, remove the resource from the list and go
// ahead trying it.
parent_.faultInfos_.remove(r);
} else {
// The timeout didn't expire yet, try the next resource.
continue;
}
}
cache_ = r;
cached_ = true;
nextIndex_ = nextIndex_ + i + 1;
return false;
}
// If we get here, all resources are currently inactive or the non-cycle
// pool has been exhausted, so there is nothing we can do.
nextIndex_ = nextIndex_ + examineCount;
return true;
}
/**
* Returns the first resource in the range.
*/
Resource front() @property {
enforce(!empty);
return cache_;
}
/**
* Removes the first resource from the range.
*
* Usually, this is combined with a call to TResourcePool.recordSuccess()
* or recordFault().
*/
void popFront() {
enforce(!empty);
cached_ = false;
}
/**
* Returns whether the range will become non-empty at some point in the
* future, and provides additional information when this will happen and
* what will be the next resource.
*
* Makes only sense to call on empty ranges.
*
* Params:
* next = The next resource that will become available.
* waitTime = The duration until that resource will become available.
*/
bool willBecomeNonempty(out Resource next, out Duration waitTime) {
// If no resources are in the pool, the range will never become non-empty.
if (resources_.empty) return false;
// If cycle mode is not enabled, a range never becomes non-empty after
// being empty once, because all the elements have already been
// used/skipped in order to become empty.
if (!parent_.cycle) return false;
auto fi = parent_.faultInfos_;
auto nextPair = minPos!"a[1].resetTime < b[1].resetTime"(
zip(fi.keys, fi.values)
).front;
next = nextPair[0];
waitTime = to!Duration(nextPair[1].resetTime - parent_.getCurrentTick_());
return true;
}
private:
this(TResourcePool parent, Resource[] resources) {
parent_ = parent;
resources_ = resources;
}
TResourcePool parent_;
/// All available resources. We keep a copy of it as to not get confused
/// when resources are added to/removed from the parent pool.
Resource[] resources_;
/// After we have determined the next element in empty(), we store it here.
Resource cache_;
/// Whether there is currently something in the cache.
bool cached_;
/// The index to start searching from at the next call to empty().
size_t nextIndex_;
}
/// Ditto
Range opSlice() {
auto res = resources_;
if (permute) {
res = array(randomCover(res, rndGen));
}
return Range(this, res);
}
/**
* Records a success for an operation on the given resource, cancelling a
* fault streak, if any.
*/
void recordSuccess(Resource resource) {
if (resource in faultInfos_) {
faultInfos_.remove(resource);
}
}
/**
* Records a fault for the given resource.
*
* If a resource fails consecutively for more than faultDisableCount times,
* it is temporarily disabled (no longer considered) until
* faultDisableDuration has passed.
*/
void recordFault(Resource resource) {
auto fi = resource in faultInfos_;
if (!fi) {
faultInfos_[resource] = FaultInfo();
fi = resource in faultInfos_;
}
++fi.count;
if (fi.count >= faultDisableCount) {
// If the resource has hit the fault count limit, disable it for
// specified duration.
fi.resetTime = getCurrentTick_() + cast(TickDuration)faultDisableDuration;
}
}
/**
* Whether to randomly permute the order of the resources in the pool when
* taking a range using opSlice().
*
* This can be used e.g. as a simple form of load balancing.
*/
bool permute = true;
/**
* Whether to keep iterating over the pool members after all have been
* returned/have failed once.
*/
bool cycle = false;
/**
* The number of consecutive faults after which a resource is disabled until
* faultDisableDuration has passed. Zero to never disable resources.
*
* Defaults to zero.
*/
ushort faultDisableCount = 0;
/**
* The duration for which a resource is no longer considered after it has
* failed too often.
*
* Defaults to one second.
*/
Duration faultDisableDuration = dur!"seconds"(1);
private:
Resource[] resources_;
FaultInfo[Resource] faultInfos_;
/// Function to get the current timestamp from some monotonic system clock.
///
/// This is overridable to be able to write timing-insensitive unit tests.
/// The extra indirection should not matter much performance-wise compared to
/// the actual system call, and by its very nature thisshould not be on a hot
/// path anyway.
typeof(&TickDuration.currSystemTick) getCurrentTick_ =
&TickDuration.currSystemTick;
}
private {
struct FaultInfo {
ushort count;
TickDuration resetTime;
}
}
unittest {
auto pool = new TResourcePool!Object([]);
enforce(pool[].empty);
Object dummyRes;
Duration dummyDur;
enforce(!pool[].willBecomeNonempty(dummyRes, dummyDur));
}
unittest {
import std.datetime;
import thrift.base;
auto a = new Object;
auto b = new Object;
auto c = new Object;
auto objs = [a, b, c];
auto pool = new TResourcePool!Object(objs);
pool.permute = false;
static Duration fakeClock;
pool.getCurrentTick_ = () => cast(TickDuration)fakeClock;
Object dummyRes = void;
Duration dummyDur = void;
{
auto r = pool[];
foreach (i, o; objs) {
enforce(!r.empty);
enforce(r.front == o);
r.popFront();
}
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
}
{
pool.faultDisableCount = 2;
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordSuccess(a);
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordFault(a);
auto r = pool[];
enforce(r.front == b);
r.popFront();
enforce(r.front == c);
r.popFront();
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
fakeClock += 2.seconds;
// Not in cycle mode, has to be still empty after the timeouts expired.
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
foreach (o; objs) pool.recordSuccess(o);
}
{
pool.faultDisableCount = 1;
pool.recordFault(a);
pool.recordFault(b);
pool.recordFault(c);
auto r = pool[];
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
foreach (o; objs) pool.recordSuccess(o);
}
pool.cycle = true;
{
auto r = pool[];
foreach (o; objs ~ objs) {
enforce(!r.empty);
enforce(r.front == o);
r.popFront();
}
}
{
pool.faultDisableCount = 2;
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordSuccess(a);
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordFault(a);
auto r = pool[];
enforce(r.front == b);
r.popFront();
enforce(r.front == c);
r.popFront();
enforce(r.front == b);
fakeClock += 2.seconds;
r.popFront();
enforce(r.front == c);
r.popFront();
enforce(r.front == a);
enforce(pool[].front == a);
foreach (o; objs) pool.recordSuccess(o);
}
{
pool.faultDisableCount = 1;
pool.recordFault(a);
fakeClock += 1.msecs;
pool.recordFault(b);
fakeClock += 1.msecs;
pool.recordFault(c);
auto r = pool[];
enforce(r.empty);
// Make sure willBecomeNonempty gets the order right.
enforce(r.willBecomeNonempty(dummyRes, dummyDur));
enforce(dummyRes == a);
enforce(dummyDur > Duration.zero);
foreach (o; objs) pool.recordSuccess(o);
}
}

View file

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Abstractions over OS-dependent socket functionality.
*/
module thrift.internal.socket;
import std.conv : to;
// FreeBSD and OS X return -1 and set ECONNRESET if socket was closed by
// the other side, we need to check for that before throwing an exception.
version (FreeBSD) {
enum connresetOnPeerShutdown = true;
} else version (OSX) {
enum connresetOnPeerShutdown = true;
} else {
enum connresetOnPeerShutdown = false;
}
version (Win32) {
import std.c.windows.winsock : WSAGetLastError, WSAEINTR, WSAEWOULDBLOCK;
import std.windows.syserror : sysErrorString;
// These are unfortunately not defined in std.c.windows.winsock, see
// http://msdn.microsoft.com/en-us/library/ms740668.aspx.
enum WSAECONNRESET = 10054;
enum WSAENOTCONN = 10057;
enum WSAETIMEDOUT = 10060;
} else {
import core.stdc.errno : errno, EAGAIN, ECONNRESET, EINPROGRESS, EINTR,
ENOTCONN, EPIPE;
import core.stdc.string : strerror;
}
/*
* CONNECT_INPROGRESS_ERRNO: set by connect() for non-blocking sockets if the
* connection could not be immediately established.
* INTERRUPTED_ERRNO: set when blocking system calls are interrupted by
* signals or similar.
* TIMEOUT_ERRNO: set when a socket timeout has been exceeded.
* WOULD_BLOCK_ERRNO: set when send/recv would block on non-blocking sockets.
*
* isSocetCloseErrno(errno): returns true if errno indicates that the socket
* is logically in closed state now.
*/
version (Win32) {
alias WSAGetLastError getSocketErrno;
enum CONNECT_INPROGRESS_ERRNO = WSAEWOULDBLOCK;
enum INTERRUPTED_ERRNO = WSAEINTR;
enum TIMEOUT_ERRNO = WSAETIMEDOUT;
enum WOULD_BLOCK_ERRNO = WSAEWOULDBLOCK;
bool isSocketCloseErrno(typeof(getSocketErrno()) errno) {
return (errno == WSAECONNRESET || errno == WSAENOTCONN);
}
} else {
alias errno getSocketErrno;
enum CONNECT_INPROGRESS_ERRNO = EINPROGRESS;
enum INTERRUPTED_ERRNO = EINTR;
enum WOULD_BLOCK_ERRNO = EAGAIN;
// TODO: The C++ TSocket implementation mentions that EAGAIN can also be
// set (undocumentedly) in out of resource conditions; it would be a good
// idea to contact the original authors of the C++ code for details and adapt
// the code accordingly.
enum TIMEOUT_ERRNO = EAGAIN;
bool isSocketCloseErrno(typeof(getSocketErrno()) errno) {
return (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN);
}
}
string socketErrnoString(uint errno) {
version (Win32) {
return sysErrorString(errno);
} else {
return to!string(strerror(errno));
}
}

View file

@ -0,0 +1,240 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.ssl;
import core.memory : GC;
import core.stdc.config;
import core.stdc.errno : errno;
import core.stdc.string : strerror;
import deimos.openssl.err;
import deimos.openssl.ssl;
import deimos.openssl.x509v3;
import std.array : empty, appender;
import std.conv : to;
import std.socket : Address;
import thrift.transport.ssl;
/**
* Checks if the peer is authorized after the SSL handshake has been
* completed on the given conncetion and throws an TSSLException if not.
*
* Params:
* ssl = The SSL connection to check.
* accessManager = The access manager to check the peer againts.
* peerAddress = The (IP) address of the peer.
* hostName = The host name of the peer.
*/
void authorize(SSL* ssl, TAccessManager accessManager,
Address peerAddress, lazy string hostName
) {
alias TAccessManager.Decision Decision;
auto rc = SSL_get_verify_result(ssl);
if (rc != X509_V_OK) {
throw new TSSLException("SSL_get_verify_result(): " ~
to!string(X509_verify_cert_error_string(rc)));
}
auto cert = SSL_get_peer_certificate(ssl);
if (cert is null) {
// Certificate is not present.
if (SSL_get_verify_mode(ssl) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
throw new TSSLException(
"Authorize: Required certificate not present.");
}
// If we don't have an access manager set, we don't intend to authorize
// the client, so everything's fine.
if (accessManager) {
throw new TSSLException(
"Authorize: Certificate required for authorization.");
}
return;
}
if (accessManager is null) {
// No access manager set, can return immediately as the cert is valid
// and all peers are authorized.
X509_free(cert);
return;
}
// both certificate and access manager are present
auto decision = accessManager.verify(peerAddress);
if (decision != Decision.SKIP) {
X509_free(cert);
if (decision != Decision.ALLOW) {
throw new TSSLException("Authorize: Access denied based on remote IP.");
}
return;
}
// Check subjectAltName(s), if present.
auto alternatives = cast(STACK_OF!(GENERAL_NAME)*)
X509_get_ext_d2i(cert, NID_subject_alt_name, null, null);
if (alternatives != null) {
auto count = sk_GENERAL_NAME_num(alternatives);
for (int i = 0; decision == Decision.SKIP && i < count; i++) {
auto name = sk_GENERAL_NAME_value(alternatives, i);
if (name is null) {
continue;
}
auto data = ASN1_STRING_data(name.d.ia5);
auto length = ASN1_STRING_length(name.d.ia5);
switch (name.type) {
case GENERAL_NAME.GEN_DNS:
decision = accessManager.verify(hostName, cast(char[])data[0 .. length]);
break;
case GENERAL_NAME.GEN_IPADD:
decision = accessManager.verify(peerAddress, data[0 .. length]);
break;
default:
// Do nothing.
}
}
// DMD @@BUG@@: Empty template arguments parens should not be needed.
sk_GENERAL_NAME_pop_free!()(alternatives, &GENERAL_NAME_free);
}
// If we are alredy done, return.
if (decision != Decision.SKIP) {
X509_free(cert);
if (decision != Decision.ALLOW) {
throw new TSSLException("Authorize: Access denied.");
}
return;
}
// Check commonName.
auto name = X509_get_subject_name(cert);
if (name !is null) {
X509_NAME_ENTRY* entry;
char* utf8;
int last = -1;
while (decision == Decision.SKIP) {
last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
if (last == -1)
break;
entry = X509_NAME_get_entry(name, last);
if (entry is null)
continue;
auto common = X509_NAME_ENTRY_get_data(entry);
auto size = ASN1_STRING_to_UTF8(&utf8, common);
decision = accessManager.verify(hostName, utf8[0 .. size]);
CRYPTO_free(utf8);
}
}
X509_free(cert);
if (decision != Decision.ALLOW) {
throw new TSSLException("Authorize: Could not authorize peer.");
}
}
/*
* OpenSSL error information used for storing D exceptions on the OpenSSL
* error stack.
*/
enum ERR_LIB_D_EXCEPTION = ERR_LIB_USER;
enum ERR_F_D_EXCEPTION = 0; // function id - what to use here?
enum ERR_R_D_EXCEPTION = 1234; // 99 and above are reserved for applications
enum ERR_FILE_D_EXCEPTION = "d_exception";
enum ERR_LINE_D_EXCEPTION = 0;
enum ERR_FLAGS_D_EXCEPTION = 0;
/**
* Returns an exception for the last.
*
* Params:
* location = An optional "location" to add to the error message (typically
* the last SSL API call).
*/
Exception getSSLException(string location = null, string clientFile = __FILE__,
size_t clientLine = __LINE__
) {
// We can return either an exception saved from D BIO code, or a "true"
// OpenSSL error. Because there can possibly be more than one error on the
// error stack, we have to fetch all of them, and pick the last, i.e. newest
// one. We concatenate multiple successive OpenSSL error messages into a
// single one, but always just return the last D expcetion.
string message; // Probably better use an Appender here.
bool hadMessage;
Exception exception;
void initMessage() {
message.destroy();
hadMessage = false;
if (!location.empty) {
message ~= location;
message ~= ": ";
}
}
initMessage();
auto errn = errno;
const(char)* file = void;
int line = void;
const(char)* data = void;
int flags = void;
c_ulong code = void;
while ((code = ERR_get_error_line_data(&file, &line, &data, &flags)) != 0) {
if (ERR_GET_REASON(code) == ERR_R_D_EXCEPTION) {
initMessage();
GC.removeRoot(cast(void*)data);
exception = cast(Exception)data;
} else {
exception = null;
if (hadMessage) {
message ~= ", ";
}
auto reason = ERR_reason_error_string(code);
if (reason) {
message ~= "SSL error: " ~ to!string(reason);
} else {
message ~= "SSL error #" ~ to!string(code);
}
hadMessage = true;
}
}
// If the last item from the stack was a D exception, throw it.
if (exception) return exception;
// We are dealing with an OpenSSL error that doesn't root in a D exception.
if (!hadMessage) {
// If we didn't get an actual error from the stack yet, try errno.
string errnString;
if (errn != 0) {
errnString = to!string(strerror(errn));
}
if (errnString.empty) {
message ~= "Unknown error";
} else {
message ~= errnString;
}
}
message ~= ".";
return new TSSLException(message, clientFile, clientLine);
}

View file

@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Provides a SSL BIO implementation wrapping a Thrift transport.
*
* This way, SSL I/O can be relayed over Thrift transport without introducing
* an additional layer of buffering, especially for the non-blocking
* transports.
*
* For the Thrift transport incarnations of the SSL entities, "tt" is used as
* prefix for clarity.
*/
module thrift.internal.ssl_bio;
import core.stdc.config;
import core.stdc.string : strlen;
import core.memory : GC;
import deimos.openssl.bio;
import deimos.openssl.err;
import thrift.base;
import thrift.internal.ssl;
import thrift.transport.base;
/**
* Creates an SSL BIO object wrapping the given transport.
*
* Exceptions thrown by the transport are pushed onto the OpenSSL error stack,
* using the location/reason values from thrift.internal.ssl.ERR_*_D_EXCEPTION.
*
* The transport is assumed to be ready for reading and writing when the BIO
* functions are called, it is not opened by the implementation.
*
* Params:
* transport = The transport to wrap.
* closeTransport = Whether the close the transport when the SSL BIO is
* closed.
*/
BIO* createTTransportBIO(TTransport transport, bool closeTransport) {
auto result = BIO_new(cast(BIO_METHOD*)&ttBioMethod);
if (!result) return null;
GC.addRoot(cast(void*)transport);
BIO_set_fd(result, closeTransport, cast(c_long)cast(void*)transport);
return result;
}
private {
// Helper to get the Thrift transport assigned with the given BIO.
TTransport trans(BIO* b) nothrow {
auto result = cast(TTransport)b.ptr;
assert(result);
return result;
}
void setError(Exception e) nothrow {
ERR_put_error(ERR_LIB_D_EXCEPTION, ERR_F_D_EXCEPTION, ERR_R_D_EXCEPTION,
ERR_FILE_D_EXCEPTION, ERR_LINE_D_EXCEPTION);
try { GC.addRoot(cast(void*)e); } catch {}
ERR_set_error_data(cast(char*)e, ERR_FLAGS_D_EXCEPTION);
}
extern(C) int ttWrite(BIO* b, const(char)* data, int length) nothrow {
assert(b);
if (!data || length <= 0) return 0;
try {
trans(b).write((cast(ubyte*)data)[0 .. length]);
return length;
} catch (Exception e) {
setError(e);
return -1;
}
}
extern(C) int ttRead(BIO* b, char* data, int length) nothrow {
assert(b);
if (!data || length <= 0) return 0;
try {
return cast(int)trans(b).read((cast(ubyte*)data)[0 .. length]);
} catch (Exception e) {
setError(e);
return -1;
}
}
extern(C) int ttPuts(BIO* b, const(char)* str) nothrow {
return ttWrite(b, str, cast(int)strlen(str));
}
extern(C) c_long ttCtrl(BIO* b, int cmd, c_long num, void* ptr) nothrow {
assert(b);
switch (cmd) {
case BIO_C_SET_FD:
// Note that close flag and "fd" are actually reversed here because we
// need 64 bit width for the pointer should probably drop BIO_set_fd
// altogether.
ttDestroy(b);
b.ptr = cast(void*)num;
b.shutdown = cast(int)ptr;
b.init_ = 1;
return 1;
case BIO_C_GET_FD:
if (!b.init_) return -1;
*(cast(void**)ptr) = b.ptr;
return cast(c_long)b.ptr;
case BIO_CTRL_GET_CLOSE:
return b.shutdown;
case BIO_CTRL_SET_CLOSE:
b.shutdown = cast(int)num;
return 1;
case BIO_CTRL_FLUSH:
try {
trans(b).flush();
return 1;
} catch (Exception e) {
setError(e);
return -1;
}
case BIO_CTRL_DUP:
// Seems like we have nothing to do on duplication, but couldn't find
// any documentation if this actually ever happens during normal SSL
// usage.
return 1;
default:
return 0;
}
}
extern(C) int ttCreate(BIO* b) nothrow {
assert(b);
b.init_ = 0;
b.num = 0; // User-defined number field, unused here.
b.ptr = null;
b.flags = 0;
return 1;
}
extern(C) int ttDestroy(BIO* b) nothrow {
if (!b) return 0;
int rc = 1;
if (b.shutdown) {
if (b.init_) {
try {
trans(b).close();
GC.removeRoot(cast(void*)trans(b));
b.ptr = null;
} catch (Exception e) {
setError(e);
rc = -1;
}
}
b.init_ = 0;
b.flags = 0;
}
return rc;
}
immutable BIO_METHOD ttBioMethod = {
BIO_TYPE_SOURCE_SINK,
"TTransport",
&ttWrite,
&ttRead,
&ttPuts,
null, // gets
&ttCtrl,
&ttCreate,
&ttDestroy,
null // callback_ctrl
};
}

View file

@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.test.protocol;
import std.exception;
import thrift.transport.memory;
import thrift.protocol.base;
version (unittest):
void testContainerSizeLimit(Protocol)() if (isTProtocol!Protocol) {
auto buffer = new TMemoryBuffer;
auto prot = new Protocol(buffer);
// Make sure reading fails if a container larger than the size limit is read.
prot.containerSizeLimit = 3;
{
prot.writeListBegin(TList(TType.I32, 4));
prot.writeI32(0); // Make sure size can be read e.g. for JSON protocol.
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readListBegin());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
{
prot.writeMapBegin(TMap(TType.I32, TType.I32, 4));
prot.writeI32(0); // Make sure size can be read e.g. for JSON protocol.
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readMapBegin());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
{
prot.writeSetBegin(TSet(TType.I32, 4));
prot.writeI32(0); // Make sure size can be read e.g. for JSON protocol.
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readSetBegin());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
// Make sure reading works if the containers are smaller than the limit or
// no limit is set.
foreach (limit; [3, 0, -1]) {
prot.containerSizeLimit = limit;
{
prot.writeListBegin(TList(TType.I32, 2));
prot.writeI32(0);
prot.writeI32(1);
prot.writeListEnd();
prot.reset();
auto list = prot.readListBegin();
enforce(list.elemType == TType.I32);
enforce(list.size == 2);
enforce(prot.readI32() == 0);
enforce(prot.readI32() == 1);
prot.readListEnd();
prot.reset();
buffer.reset();
}
{
prot.writeMapBegin(TMap(TType.I32, TType.I32, 2));
prot.writeI32(0);
prot.writeI32(1);
prot.writeI32(2);
prot.writeI32(3);
prot.writeMapEnd();
prot.reset();
auto map = prot.readMapBegin();
enforce(map.keyType == TType.I32);
enforce(map.valueType == TType.I32);
enforce(map.size == 2);
enforce(prot.readI32() == 0);
enforce(prot.readI32() == 1);
enforce(prot.readI32() == 2);
enforce(prot.readI32() == 3);
prot.readMapEnd();
prot.reset();
buffer.reset();
}
{
prot.writeSetBegin(TSet(TType.I32, 2));
prot.writeI32(0);
prot.writeI32(1);
prot.writeSetEnd();
prot.reset();
auto set = prot.readSetBegin();
enforce(set.elemType == TType.I32);
enforce(set.size == 2);
enforce(prot.readI32() == 0);
enforce(prot.readI32() == 1);
prot.readSetEnd();
prot.reset();
buffer.reset();
}
}
}
void testStringSizeLimit(Protocol)() if (isTProtocol!Protocol) {
auto buffer = new TMemoryBuffer;
auto prot = new Protocol(buffer);
// Make sure reading fails if a string larger than the size limit is read.
prot.stringSizeLimit = 3;
{
prot.writeString("asdf");
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readString());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
{
prot.writeBinary([1, 2, 3, 4]);
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readBinary());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
// Make sure reading works if the containers are smaller than the limit or
// no limit is set.
foreach (limit; [3, 0, -1]) {
prot.containerSizeLimit = limit;
{
prot.writeString("as");
prot.reset();
enforce(prot.readString() == "as");
prot.reset();
buffer.reset();
}
{
prot.writeBinary([1, 2]);
prot.reset();
enforce(prot.readBinary() == [1, 2]);
prot.reset();
buffer.reset();
}
}
}

View file

@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.test.server;
import core.sync.condition;
import core.sync.mutex;
import core.thread : Thread;
import std.datetime;
import std.exception : enforce;
import std.typecons : WhiteHole;
import std.variant : Variant;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.socket;
import thrift.transport.base;
import thrift.util.cancellation;
version(unittest):
/**
* Tests if serving is stopped correctly if the cancellation passed to serve()
* is triggered.
*
* Because the tests are run many times in a loop, this is indirectly also a
* test whether socket, etc. handles are cleaned up correctly, because the
* application will likely run out of handles otherwise.
*/
void testServeCancel(Server)(void delegate(Server) serverSetup = null) if (
is(Server : TServer)
) {
auto proc = new WhiteHole!TProcessor;
auto tf = new TTransportFactory;
auto pf = new TBinaryProtocolFactory!();
// Need a special case for TNonblockingServer which doesn't use
// TServerTransport.
static if (__traits(compiles, new Server(proc, 0, tf, pf))) {
auto server = new Server(proc, 0, tf, pf);
} else {
auto server = new Server(proc, new TServerSocket(0), tf, pf);
}
// On Windows, we use TCP sockets to replace socketpair(). Since they stay
// in TIME_WAIT for some time even if they are properly closed, we have to use
// a lower number of iterations to avoid running out of ports/buffer space.
version (Windows) {
enum ITERATIONS = 100;
} else {
enum ITERATIONS = 10000;
}
if (serverSetup) serverSetup(server);
auto servingMutex = new Mutex;
auto servingCondition = new Condition(servingMutex);
auto doneMutex = new Mutex;
auto doneCondition = new Condition(doneMutex);
class CancellingHandler : TServerEventHandler {
void preServe() {
synchronized (servingMutex) {
servingCondition.notifyAll();
}
}
Variant createContext(TProtocol input, TProtocol output) { return Variant.init; }
void deleteContext(Variant serverContext, TProtocol input, TProtocol output) {}
void preProcess(Variant serverContext, TTransport transport) {}
}
server.eventHandler = new CancellingHandler;
foreach (i; 0 .. ITERATIONS) {
synchronized (servingMutex) {
auto cancel = new TCancellationOrigin;
synchronized (doneMutex) {
auto serverThread = new Thread({
server.serve(cancel);
synchronized (doneMutex) {
doneCondition.notifyAll();
}
});
serverThread.isDaemon = true;
serverThread.start();
servingCondition.wait();
cancel.trigger();
enforce(doneCondition.wait(dur!"msecs"(3*1000)));
serverThread.join();
}
}
}
}

View file

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.traits;
import std.traits;
/**
* Adds »nothrow« to the type of the passed function pointer/delegate, if it
* is not already present.
*
* Technically, assumeNothrow just performs a cast, but using it has the
* advantage of being explicitly about the operation that is performed.
*/
auto assumeNothrow(T)(T t) if (isFunctionPointer!T || isDelegate!T) {
enum attrs = functionAttributes!T | FunctionAttribute.nothrow_;
return cast(SetFunctionAttributes!(T, functionLinkage!T, attrs)) t;
}

View file

@ -0,0 +1,449 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Defines the basic interface for a Thrift protocol and associated exception
* types.
*
* Most parts of the protocol API are typically not used in client code, as
* the actual serialization code is generated by thrift.codegen.* the only
* interesting thing usually is that there are protocols which can be created
* from transports and passed around.
*/
module thrift.protocol.base;
import thrift.base;
import thrift.transport.base;
/**
* The field types Thrift protocols support.
*/
enum TType : byte {
STOP = 0, /// Used to mark the end of a sequence of fields.
VOID = 1, ///
BOOL = 2, ///
BYTE = 3, ///
DOUBLE = 4, ///
I16 = 6, ///
I32 = 8, ///
I64 = 10, ///
STRING = 11, ///
STRUCT = 12, ///
MAP = 13, ///
SET = 14, ///
LIST = 15 ///
}
/**
* Types of Thrift RPC messages.
*/
enum TMessageType : byte {
CALL = 1, /// Call of a normal, two-way RPC method.
REPLY = 2, /// Reply to a normal method call.
EXCEPTION = 3, /// Reply to a method call if target raised a TApplicationException.
ONEWAY = 4 /// Call of a one-way RPC method which is not followed by a reply.
}
/**
* Descriptions of Thrift entities.
*/
struct TField {
string name;
TType type;
short id;
}
/// ditto
struct TList {
TType elemType;
size_t size;
}
/// ditto
struct TMap {
TType keyType;
TType valueType;
size_t size;
}
/// ditto
struct TMessage {
string name;
TMessageType type;
int seqid;
}
/// ditto
struct TSet {
TType elemType;
size_t size;
}
/// ditto
struct TStruct {
string name;
}
/**
* Interface for a Thrift protocol implementation. Essentially, it defines
* a way of reading and writing all the base types, plus a mechanism for
* writing out structs with indexed fields.
*
* TProtocol objects should not be shared across multiple encoding contexts,
* as they may need to maintain internal state in some protocols (e.g. JSON).
* Note that is is acceptable for the TProtocol module to do its own internal
* buffered reads/writes to the underlying TTransport where appropriate (i.e.
* when parsing an input XML stream, reading could be batched rather than
* looking ahead character by character for a close tag).
*/
interface TProtocol {
/// The underlying transport used by the protocol.
TTransport transport() @property;
/*
* Writing methods.
*/
void writeBool(bool b); ///
void writeByte(byte b); ///
void writeI16(short i16); ///
void writeI32(int i32); ///
void writeI64(long i64); ///
void writeDouble(double dub); ///
void writeString(string str); ///
void writeBinary(ubyte[] buf); ///
void writeMessageBegin(TMessage message); ///
void writeMessageEnd(); ///
void writeStructBegin(TStruct tstruct); ///
void writeStructEnd(); ///
void writeFieldBegin(TField field); ///
void writeFieldEnd(); ///
void writeFieldStop(); ///
void writeListBegin(TList list); ///
void writeListEnd(); ///
void writeMapBegin(TMap map); ///
void writeMapEnd(); ///
void writeSetBegin(TSet set); ///
void writeSetEnd(); ///
/*
* Reading methods.
*/
bool readBool(); ///
byte readByte(); ///
short readI16(); ///
int readI32(); ///
long readI64(); ///
double readDouble(); ///
string readString(); ///
ubyte[] readBinary(); ///
TMessage readMessageBegin(); ///
void readMessageEnd(); ///
TStruct readStructBegin(); ///
void readStructEnd(); ///
TField readFieldBegin(); ///
void readFieldEnd(); ///
TList readListBegin(); ///
void readListEnd(); ///
TMap readMapBegin(); ///
void readMapEnd(); ///
TSet readSetBegin(); ///
void readSetEnd(); ///
/**
* Reset any internal state back to a blank slate, if the protocol is
* stateful.
*/
void reset();
}
/**
* true if T is a TProtocol.
*/
template isTProtocol(T) {
enum isTProtocol = is(T : TProtocol);
}
unittest {
static assert(isTProtocol!TProtocol);
static assert(!isTProtocol!void);
}
/**
* Creates a protocol operating on a given transport.
*/
interface TProtocolFactory {
///
TProtocol getProtocol(TTransport trans);
}
/**
* A protocol-level exception.
*/
class TProtocolException : TException {
/// The possible exception types.
enum Type {
UNKNOWN, ///
INVALID_DATA, ///
NEGATIVE_SIZE, ///
SIZE_LIMIT, ///
BAD_VERSION, ///
NOT_IMPLEMENTED, ///
DEPTH_LIMIT ///
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
static string msgForType(Type type) {
switch (type) {
case Type.UNKNOWN: return "Unknown protocol exception";
case Type.INVALID_DATA: return "Invalid data";
case Type.NEGATIVE_SIZE: return "Negative size";
case Type.SIZE_LIMIT: return "Exceeded size limit";
case Type.BAD_VERSION: return "Invalid version";
case Type.NOT_IMPLEMENTED: return "Not implemented";
case Type.DEPTH_LIMIT: return "Exceeded size limit";
default: return "(Invalid exception type)";
}
}
this(msgForType(type), type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() const @property {
return type_;
}
protected:
Type type_;
}
/**
* Skips a field of the given type on the protocol.
*
* The main purpose of skip() is to allow treating struct and container types,
* (where multiple primitive types have to be skipped) the same as scalar types
* in generated code.
*/
void skip(Protocol)(Protocol prot, TType type) if (is(Protocol : TProtocol)) {
final switch (type) {
case TType.BOOL:
prot.readBool();
break;
case TType.BYTE:
prot.readByte();
break;
case TType.I16:
prot.readI16();
break;
case TType.I32:
prot.readI32();
break;
case TType.I64:
prot.readI64();
break;
case TType.DOUBLE:
prot.readDouble();
break;
case TType.STRING:
prot.readBinary();
break;
case TType.STRUCT:
prot.readStructBegin();
while (true) {
auto f = prot.readFieldBegin();
if (f.type == TType.STOP) break;
skip(prot, f.type);
prot.readFieldEnd();
}
prot.readStructEnd();
break;
case TType.LIST:
auto l = prot.readListBegin();
foreach (i; 0 .. l.size) {
skip(prot, l.elemType);
}
prot.readListEnd();
break;
case TType.MAP:
auto m = prot.readMapBegin();
foreach (i; 0 .. m.size) {
skip(prot, m.keyType);
skip(prot, m.valueType);
}
prot.readMapEnd();
break;
case TType.SET:
auto s = prot.readSetBegin();
foreach (i; 0 .. s.size) {
skip(prot, s.elemType);
}
prot.readSetEnd();
break;
case TType.STOP: goto case;
case TType.VOID:
assert(false, "Invalid field type passed.");
}
}
/**
* Application-level exception.
*
* It is thrown if an RPC call went wrong on the application layer, e.g. if
* the receiver does not know the method name requested or a method invoked by
* the service processor throws an exception not part of the Thrift API.
*/
class TApplicationException : TException {
/// The possible exception types.
enum Type {
UNKNOWN = 0, ///
UNKNOWN_METHOD = 1, ///
INVALID_MESSAGE_TYPE = 2, ///
WRONG_METHOD_NAME = 3, ///
BAD_SEQUENCE_ID = 4, ///
MISSING_RESULT = 5, ///
INTERNAL_ERROR = 6, ///
PROTOCOL_ERROR = 7, ///
INVALID_TRANSFORM = 8, ///
INVALID_PROTOCOL = 9, ///
UNSUPPORTED_CLIENT_TYPE = 10 ///
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
static string msgForType(Type type) {
switch (type) {
case Type.UNKNOWN: return "Unknown application exception";
case Type.UNKNOWN_METHOD: return "Unknown method";
case Type.INVALID_MESSAGE_TYPE: return "Invalid message type";
case Type.WRONG_METHOD_NAME: return "Wrong method name";
case Type.BAD_SEQUENCE_ID: return "Bad sequence identifier";
case Type.MISSING_RESULT: return "Missing result";
case Type.INTERNAL_ERROR: return "Internal error";
case Type.PROTOCOL_ERROR: return "Protocol error";
case Type.INVALID_TRANSFORM: return "Invalid transform";
case Type.INVALID_PROTOCOL: return "Invalid protocol";
case Type.UNSUPPORTED_CLIENT_TYPE: return "Unsupported client type";
default: return "(Invalid exception type)";
}
}
this(msgForType(type), type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() @property const {
return type_;
}
// TODO: Replace hand-written read()/write() with thrift.codegen templates.
///
void read(TProtocol iprot) {
iprot.readStructBegin();
while (true) {
auto f = iprot.readFieldBegin();
if (f.type == TType.STOP) break;
switch (f.id) {
case 1:
if (f.type == TType.STRING) {
msg = iprot.readString();
} else {
skip(iprot, f.type);
}
break;
case 2:
if (f.type == TType.I32) {
type_ = cast(Type)iprot.readI32();
} else {
skip(iprot, f.type);
}
break;
default:
skip(iprot, f.type);
break;
}
}
iprot.readStructEnd();
}
///
void write(TProtocol oprot) const {
oprot.writeStructBegin(TStruct("TApplicationException"));
if (msg != null) {
oprot.writeFieldBegin(TField("message", TType.STRING, 1));
oprot.writeString(msg);
oprot.writeFieldEnd();
}
oprot.writeFieldBegin(TField("type", TType.I32, 2));
oprot.writeI32(type_);
oprot.writeFieldEnd();
oprot.writeFieldStop();
oprot.writeStructEnd();
}
private:
Type type_;
}

View file

@ -0,0 +1,414 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.protocol.binary;
import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;
/**
* TProtocol implementation of the Binary Thrift protocol.
*/
final class TBinaryProtocol(Transport = TTransport) if (
isTTransport!Transport
) : TProtocol {
/**
* Constructs a new instance.
*
* Params:
* trans = The transport to use.
* containerSizeLimit = If positive, the container size is limited to the
* given number of items.
* stringSizeLimit = If positive, the string length is limited to the
* given number of bytes.
* strictRead = If false, old peers which do not include the protocol
* version are tolerated.
* strictWrite = Whether to include the protocol version in the header.
*/
this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0,
bool strictRead = false, bool strictWrite = true
) {
trans_ = trans;
this.containerSizeLimit = containerSizeLimit;
this.stringSizeLimit = stringSizeLimit;
this.strictRead = strictRead;
this.strictWrite = strictWrite;
}
Transport transport() @property {
return trans_;
}
void reset() {}
/**
* If false, old peers which do not include the protocol version in the
* message header are tolerated.
*
* Defaults to false.
*/
bool strictRead;
/**
* Whether to include the protocol version in the message header (older
* versions didn't).
*
* Defaults to true.
*/
bool strictWrite;
/**
* If positive, limits the number of items of deserialized containers to the
* given amount.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int containerSizeLimit;
/**
* If positive, limits the length of deserialized strings/binary data to the
* given number of bytes.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int stringSizeLimit;
/*
* Writing methods.
*/
void writeBool(bool b) {
writeByte(b ? 1 : 0);
}
void writeByte(byte b) {
trans_.write((cast(ubyte*)&b)[0 .. 1]);
}
void writeI16(short i16) {
short net = hostToNet(i16);
trans_.write((cast(ubyte*)&net)[0 .. 2]);
}
void writeI32(int i32) {
int net = hostToNet(i32);
trans_.write((cast(ubyte*)&net)[0 .. 4]);
}
void writeI64(long i64) {
long net = hostToNet(i64);
trans_.write((cast(ubyte*)&net)[0 .. 8]);
}
void writeDouble(double dub) {
static assert(double.sizeof == ulong.sizeof);
auto bits = hostToNet(*cast(ulong*)(&dub));
trans_.write((cast(ubyte*)&bits)[0 .. 8]);
}
void writeString(string str) {
writeBinary(cast(ubyte[])str);
}
void writeBinary(ubyte[] buf) {
assert(buf.length <= int.max);
writeI32(cast(int)buf.length);
trans_.write(buf);
}
void writeMessageBegin(TMessage message) {
if (strictWrite) {
int versn = VERSION_1 | message.type;
writeI32(versn);
writeString(message.name);
writeI32(message.seqid);
} else {
writeString(message.name);
writeByte(message.type);
writeI32(message.seqid);
}
}
void writeMessageEnd() {}
void writeStructBegin(TStruct tstruct) {}
void writeStructEnd() {}
void writeFieldBegin(TField field) {
writeByte(field.type);
writeI16(field.id);
}
void writeFieldEnd() {}
void writeFieldStop() {
writeByte(TType.STOP);
}
void writeListBegin(TList list) {
assert(list.size <= int.max);
writeByte(list.elemType);
writeI32(cast(int)list.size);
}
void writeListEnd() {}
void writeMapBegin(TMap map) {
assert(map.size <= int.max);
writeByte(map.keyType);
writeByte(map.valueType);
writeI32(cast(int)map.size);
}
void writeMapEnd() {}
void writeSetBegin(TSet set) {
assert(set.size <= int.max);
writeByte(set.elemType);
writeI32(cast(int)set.size);
}
void writeSetEnd() {}
/*
* Reading methods.
*/
bool readBool() {
return readByte() != 0;
}
byte readByte() {
ubyte[1] b = void;
trans_.readAll(b);
return cast(byte)b[0];
}
short readI16() {
IntBuf!short b = void;
trans_.readAll(b.bytes);
return netToHost(b.value);
}
int readI32() {
IntBuf!int b = void;
trans_.readAll(b.bytes);
return netToHost(b.value);
}
long readI64() {
IntBuf!long b = void;
trans_.readAll(b.bytes);
return netToHost(b.value);
}
double readDouble() {
IntBuf!long b = void;
trans_.readAll(b.bytes);
b.value = netToHost(b.value);
return *cast(double*)(&b.value);
}
string readString() {
return cast(string)readBinary();
}
ubyte[] readBinary() {
return readBinaryBody(readSize(stringSizeLimit));
}
TMessage readMessageBegin() {
TMessage msg = void;
int size = readI32();
if (size < 0) {
int versn = size & VERSION_MASK;
if (versn != VERSION_1) {
throw new TProtocolException("Bad protocol version.",
TProtocolException.Type.BAD_VERSION);
}
msg.type = cast(TMessageType)(size & MESSAGE_TYPE_MASK);
msg.name = readString();
msg.seqid = readI32();
} else {
if (strictRead) {
throw new TProtocolException(
"Protocol version missing, old client?",
TProtocolException.Type.BAD_VERSION);
} else {
if (size < 0) {
throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
}
msg.name = cast(string)readBinaryBody(size);
msg.type = cast(TMessageType)(readByte());
msg.seqid = readI32();
}
}
return msg;
}
void readMessageEnd() {}
TStruct readStructBegin() {
return TStruct();
}
void readStructEnd() {}
TField readFieldBegin() {
TField f = void;
f.name = null;
f.type = cast(TType)readByte();
if (f.type == TType.STOP) return f;
f.id = readI16();
return f;
}
void readFieldEnd() {}
TList readListBegin() {
return TList(cast(TType)readByte(), readSize(containerSizeLimit));
}
void readListEnd() {}
TMap readMapBegin() {
return TMap(cast(TType)readByte(), cast(TType)readByte(),
readSize(containerSizeLimit));
}
void readMapEnd() {}
TSet readSetBegin() {
return TSet(cast(TType)readByte(), readSize(containerSizeLimit));
}
void readSetEnd() {}
private:
ubyte[] readBinaryBody(int size) {
if (size == 0) {
return null;
}
auto buf = uninitializedArray!(ubyte[])(size);
trans_.readAll(buf);
return buf;
}
int readSize(int limit) {
auto size = readI32();
if (size < 0) {
throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
} else if (limit > 0 && size > limit) {
throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
}
return size;
}
enum MESSAGE_TYPE_MASK = 0x000000ff;
enum VERSION_MASK = 0xffff0000;
enum VERSION_1 = 0x80010000;
Transport trans_;
}
/**
* TBinaryProtocol construction helper to avoid having to explicitly specify
* the transport type, i.e. to allow the constructor being called using IFTI
* (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
* enhancement requet 6082)).
*/
TBinaryProtocol!Transport tBinaryProtocol(Transport)(Transport trans,
int containerSizeLimit = 0, int stringSizeLimit = 0,
bool strictRead = false, bool strictWrite = true
) if (isTTransport!Transport) {
return new TBinaryProtocol!Transport(trans, containerSizeLimit,
stringSizeLimit, strictRead, strictWrite);
}
unittest {
import std.exception;
import thrift.transport.memory;
// Check the message header format.
auto buf = new TMemoryBuffer;
auto binary = tBinaryProtocol(buf);
binary.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));
auto header = new ubyte[15];
buf.readAll(header);
enforce(header == [
128, 1, 0, 1, // Version 1, TMessageType.CALL
0, 0, 0, 3, // Method name length
102, 111, 111, // Method name ("foo")
0, 0, 0, 0, // Sequence id
]);
}
unittest {
import thrift.internal.test.protocol;
testContainerSizeLimit!(TBinaryProtocol!())();
testStringSizeLimit!(TBinaryProtocol!())();
}
/**
* TProtocolFactory creating a TBinaryProtocol instance for passed in
* transports.
*
* The optional Transports template tuple parameter can be used to specify
* one or more TTransport implementations to specifically instantiate
* TBinaryProtocol for. If the actual transport types encountered at
* runtime match one of the transports in the list, a specialized protocol
* instance is created. Otherwise, a generic TTransport version is used.
*/
class TBinaryProtocolFactory(Transports...) if (
allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
///
this (int containerSizeLimit = 0, int stringSizeLimit = 0,
bool strictRead = false, bool strictWrite = true
) {
strictRead_ = strictRead;
strictWrite_ = strictWrite;
containerSizeLimit_ = containerSizeLimit;
stringSizeLimit_ = stringSizeLimit;
}
TProtocol getProtocol(TTransport trans) const {
foreach (Transport; TypeTuple!(Transports, TTransport)) {
auto concreteTrans = cast(Transport)trans;
if (concreteTrans) {
return new TBinaryProtocol!Transport(concreteTrans,
containerSizeLimit_, stringSizeLimit_, strictRead_, strictWrite_);
}
}
throw new TProtocolException(
"Passed null transport to TBinaryProtocolFactoy.");
}
protected:
bool strictRead_;
bool strictWrite_;
int containerSizeLimit_;
int stringSizeLimit_;
}

View file

@ -0,0 +1,698 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.protocol.compact;
import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;
/**
* D implementation of the Compact protocol.
*
* See THRIFT-110 for a protocol description. This implementation is based on
* the C++ one.
*/
final class TCompactProtocol(Transport = TTransport) if (
isTTransport!Transport
) : TProtocol {
/**
* Constructs a new instance.
*
* Params:
* trans = The transport to use.
* containerSizeLimit = If positive, the container size is limited to the
* given number of items.
* stringSizeLimit = If positive, the string length is limited to the
* given number of bytes.
*/
this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0) {
trans_ = trans;
this.containerSizeLimit = containerSizeLimit;
this.stringSizeLimit = stringSizeLimit;
}
Transport transport() @property {
return trans_;
}
void reset() {
lastFieldId_ = 0;
fieldIdStack_ = null;
booleanField_ = TField.init;
hasBoolValue_ = false;
}
/**
* If positive, limits the number of items of deserialized containers to the
* given amount.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int containerSizeLimit;
/**
* If positive, limits the length of deserialized strings/binary data to the
* given number of bytes.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int stringSizeLimit;
/*
* Writing methods.
*/
void writeBool(bool b) {
if (booleanField_.name !is null) {
// we haven't written the field header yet
writeFieldBeginInternal(booleanField_,
b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
booleanField_.name = null;
} else {
// we're not part of a field, so just write the value
writeByte(b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
}
}
void writeByte(byte b) {
trans_.write((cast(ubyte*)&b)[0..1]);
}
void writeI16(short i16) {
writeVarint32(i32ToZigzag(i16));
}
void writeI32(int i32) {
writeVarint32(i32ToZigzag(i32));
}
void writeI64(long i64) {
writeVarint64(i64ToZigzag(i64));
}
void writeDouble(double dub) {
ulong bits = hostToLe(*cast(ulong*)(&dub));
trans_.write((cast(ubyte*)&bits)[0 .. 8]);
}
void writeString(string str) {
writeBinary(cast(ubyte[])str);
}
void writeBinary(ubyte[] buf) {
assert(buf.length <= int.max);
writeVarint32(cast(int)buf.length);
trans_.write(buf);
}
void writeMessageBegin(TMessage msg) {
writeByte(cast(byte)PROTOCOL_ID);
writeByte(cast(byte)((VERSION_N & VERSION_MASK) |
((cast(int)msg.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)));
writeVarint32(msg.seqid);
writeString(msg.name);
}
void writeMessageEnd() {}
void writeStructBegin(TStruct tstruct) {
fieldIdStack_ ~= lastFieldId_;
lastFieldId_ = 0;
}
void writeStructEnd() {
lastFieldId_ = fieldIdStack_[$ - 1];
fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
fieldIdStack_.assumeSafeAppend();
}
void writeFieldBegin(TField field) {
if (field.type == TType.BOOL) {
booleanField_.name = field.name;
booleanField_.type = field.type;
booleanField_.id = field.id;
} else {
return writeFieldBeginInternal(field);
}
}
void writeFieldEnd() {}
void writeFieldStop() {
writeByte(TType.STOP);
}
void writeListBegin(TList list) {
writeCollectionBegin(list.elemType, list.size);
}
void writeListEnd() {}
void writeMapBegin(TMap map) {
if (map.size == 0) {
writeByte(0);
} else {
assert(map.size <= int.max);
writeVarint32(cast(int)map.size);
writeByte(cast(byte)(toCType(map.keyType) << 4 | toCType(map.valueType)));
}
}
void writeMapEnd() {}
void writeSetBegin(TSet set) {
writeCollectionBegin(set.elemType, set.size);
}
void writeSetEnd() {}
/*
* Reading methods.
*/
bool readBool() {
if (hasBoolValue_ == true) {
hasBoolValue_ = false;
return boolValue_;
}
return readByte() == CType.BOOLEAN_TRUE;
}
byte readByte() {
ubyte[1] b = void;
trans_.readAll(b);
return cast(byte)b[0];
}
short readI16() {
return cast(short)zigzagToI32(readVarint32());
}
int readI32() {
return zigzagToI32(readVarint32());
}
long readI64() {
return zigzagToI64(readVarint64());
}
double readDouble() {
IntBuf!long b = void;
trans_.readAll(b.bytes);
b.value = leToHost(b.value);
return *cast(double*)(&b.value);
}
string readString() {
return cast(string)readBinary();
}
ubyte[] readBinary() {
auto size = readVarint32();
checkSize(size, stringSizeLimit);
if (size == 0) {
return null;
}
auto buf = uninitializedArray!(ubyte[])(size);
trans_.readAll(buf);
return buf;
}
TMessage readMessageBegin() {
TMessage msg = void;
auto protocolId = readByte();
if (protocolId != cast(byte)PROTOCOL_ID) {
throw new TProtocolException("Bad protocol identifier",
TProtocolException.Type.BAD_VERSION);
}
auto versionAndType = readByte();
auto ver = versionAndType & VERSION_MASK;
if (ver != VERSION_N) {
throw new TProtocolException("Bad protocol version",
TProtocolException.Type.BAD_VERSION);
}
msg.type = cast(TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
msg.seqid = readVarint32();
msg.name = readString();
return msg;
}
void readMessageEnd() {}
TStruct readStructBegin() {
fieldIdStack_ ~= lastFieldId_;
lastFieldId_ = 0;
return TStruct();
}
void readStructEnd() {
lastFieldId_ = fieldIdStack_[$ - 1];
fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
}
TField readFieldBegin() {
TField f = void;
f.name = null;
auto bite = readByte();
auto type = cast(CType)(bite & 0x0f);
if (type == CType.STOP) {
// Struct stop byte, nothing more to do.
f.id = 0;
f.type = TType.STOP;
return f;
}
// Mask off the 4 MSB of the type header, which could contain a field id
// delta.
auto modifier = cast(short)((bite & 0xf0) >> 4);
if (modifier > 0) {
f.id = cast(short)(lastFieldId_ + modifier);
} else {
// Delta encoding not used, just read the id as usual.
f.id = readI16();
}
f.type = getTType(type);
if (type == CType.BOOLEAN_TRUE || type == CType.BOOLEAN_FALSE) {
// For boolean fields, the value is encoded in the type keep it around
// for the readBool() call.
hasBoolValue_ = true;
boolValue_ = (type == CType.BOOLEAN_TRUE ? true : false);
}
lastFieldId_ = f.id;
return f;
}
void readFieldEnd() {}
TList readListBegin() {
auto sizeAndType = readByte();
auto lsize = (sizeAndType >> 4) & 0xf;
if (lsize == 0xf) {
lsize = readVarint32();
}
checkSize(lsize, containerSizeLimit);
TList l = void;
l.elemType = getTType(cast(CType)(sizeAndType & 0x0f));
l.size = cast(size_t)lsize;
return l;
}
void readListEnd() {}
TMap readMapBegin() {
TMap m = void;
auto size = readVarint32();
ubyte kvType;
if (size != 0) {
kvType = readByte();
}
checkSize(size, containerSizeLimit);
m.size = size;
m.keyType = getTType(cast(CType)(kvType >> 4));
m.valueType = getTType(cast(CType)(kvType & 0xf));
return m;
}
void readMapEnd() {}
TSet readSetBegin() {
auto sizeAndType = readByte();
auto lsize = (sizeAndType >> 4) & 0xf;
if (lsize == 0xf) {
lsize = readVarint32();
}
checkSize(lsize, containerSizeLimit);
TSet s = void;
s.elemType = getTType(cast(CType)(sizeAndType & 0xf));
s.size = cast(size_t)lsize;
return s;
}
void readSetEnd() {}
private:
void writeFieldBeginInternal(TField field, byte typeOverride = -1) {
// If there's a type override, use that.
auto typeToWrite = (typeOverride == -1 ? toCType(field.type) : typeOverride);
// check if we can use delta encoding for the field id
if (field.id > lastFieldId_ && (field.id - lastFieldId_) <= 15) {
// write them together
writeByte(cast(byte)((field.id - lastFieldId_) << 4 | typeToWrite));
} else {
// write them separate
writeByte(cast(byte)typeToWrite);
writeI16(field.id);
}
lastFieldId_ = field.id;
}
void writeCollectionBegin(TType elemType, size_t size) {
if (size <= 14) {
writeByte(cast(byte)(size << 4 | toCType(elemType)));
} else {
assert(size <= int.max);
writeByte(cast(byte)(0xf0 | toCType(elemType)));
writeVarint32(cast(int)size);
}
}
void writeVarint32(uint n) {
ubyte[5] buf = void;
ubyte wsize;
while (true) {
if ((n & ~0x7F) == 0) {
buf[wsize++] = cast(ubyte)n;
break;
} else {
buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
n >>= 7;
}
}
trans_.write(buf[0 .. wsize]);
}
/*
* Write an i64 as a varint. Results in 1-10 bytes on the wire.
*/
void writeVarint64(ulong n) {
ubyte[10] buf = void;
ubyte wsize;
while (true) {
if ((n & ~0x7FL) == 0) {
buf[wsize++] = cast(ubyte)n;
break;
} else {
buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
n >>= 7;
}
}
trans_.write(buf[0 .. wsize]);
}
/*
* Convert l into a zigzag long. This allows negative numbers to be
* represented compactly as a varint.
*/
ulong i64ToZigzag(long l) {
return (l << 1) ^ (l >> 63);
}
/*
* Convert n into a zigzag int. This allows negative numbers to be
* represented compactly as a varint.
*/
uint i32ToZigzag(int n) {
return (n << 1) ^ (n >> 31);
}
CType toCType(TType type) {
final switch (type) {
case TType.STOP:
return CType.STOP;
case TType.BOOL:
return CType.BOOLEAN_TRUE;
case TType.BYTE:
return CType.BYTE;
case TType.DOUBLE:
return CType.DOUBLE;
case TType.I16:
return CType.I16;
case TType.I32:
return CType.I32;
case TType.I64:
return CType.I64;
case TType.STRING:
return CType.BINARY;
case TType.STRUCT:
return CType.STRUCT;
case TType.MAP:
return CType.MAP;
case TType.SET:
return CType.SET;
case TType.LIST:
return CType.LIST;
case TType.VOID:
assert(false, "Invalid type passed.");
}
}
int readVarint32() {
return cast(int)readVarint64();
}
long readVarint64() {
ulong val;
ubyte shift;
ubyte[10] buf = void; // 64 bits / (7 bits/byte) = 10 bytes.
auto bufSize = buf.sizeof;
auto borrowed = trans_.borrow(buf.ptr, bufSize);
ubyte rsize;
if (borrowed) {
// Fast path.
while (true) {
auto bite = borrowed[rsize];
rsize++;
val |= cast(ulong)(bite & 0x7f) << shift;
shift += 7;
if (!(bite & 0x80)) {
trans_.consume(rsize);
return val;
}
// Have to check for invalid data so we don't crash.
if (rsize == buf.sizeof) {
throw new TProtocolException(TProtocolException.Type.INVALID_DATA,
"Variable-length int over 10 bytes.");
}
}
} else {
// Slow path.
while (true) {
ubyte[1] bite;
trans_.readAll(bite);
++rsize;
val |= cast(ulong)(bite[0] & 0x7f) << shift;
shift += 7;
if (!(bite[0] & 0x80)) {
return val;
}
// Might as well check for invalid data on the slow path too.
if (rsize >= buf.sizeof) {
throw new TProtocolException(TProtocolException.Type.INVALID_DATA,
"Variable-length int over 10 bytes.");
}
}
}
}
/*
* Convert from zigzag int to int.
*/
int zigzagToI32(uint n) {
return (n >> 1) ^ -(n & 1);
}
/*
* Convert from zigzag long to long.
*/
long zigzagToI64(ulong n) {
return (n >> 1) ^ -(n & 1);
}
TType getTType(CType type) {
final switch (type) {
case CType.STOP:
return TType.STOP;
case CType.BOOLEAN_FALSE:
return TType.BOOL;
case CType.BOOLEAN_TRUE:
return TType.BOOL;
case CType.BYTE:
return TType.BYTE;
case CType.I16:
return TType.I16;
case CType.I32:
return TType.I32;
case CType.I64:
return TType.I64;
case CType.DOUBLE:
return TType.DOUBLE;
case CType.BINARY:
return TType.STRING;
case CType.LIST:
return TType.LIST;
case CType.SET:
return TType.SET;
case CType.MAP:
return TType.MAP;
case CType.STRUCT:
return TType.STRUCT;
}
}
void checkSize(int size, int limit) {
if (size < 0) {
throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
} else if (limit > 0 && size > limit) {
throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
}
}
enum PROTOCOL_ID = 0x82;
enum VERSION_N = 1;
enum VERSION_MASK = 0b0001_1111;
enum TYPE_MASK = 0b1110_0000;
enum TYPE_BITS = 0b0000_0111;
enum TYPE_SHIFT_AMOUNT = 5;
// Probably need to implement a better stack at some point.
short[] fieldIdStack_;
short lastFieldId_;
TField booleanField_;
bool hasBoolValue_;
bool boolValue_;
Transport trans_;
}
/**
* TCompactProtocol construction helper to avoid having to explicitly specify
* the transport type, i.e. to allow the constructor being called using IFTI
* (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
* enhancement requet 6082)).
*/
TCompactProtocol!Transport tCompactProtocol(Transport)(Transport trans,
int containerSizeLimit = 0, int stringSizeLimit = 0
) if (isTTransport!Transport)
{
return new TCompactProtocol!Transport(trans,
containerSizeLimit, stringSizeLimit);
}
private {
enum CType : ubyte {
STOP = 0x0,
BOOLEAN_TRUE = 0x1,
BOOLEAN_FALSE = 0x2,
BYTE = 0x3,
I16 = 0x4,
I32 = 0x5,
I64 = 0x6,
DOUBLE = 0x7,
BINARY = 0x8,
LIST = 0x9,
SET = 0xa,
MAP = 0xb,
STRUCT = 0xc
}
static assert(CType.max <= 0xf,
"Compact protocol wire type representation must fit into 4 bits.");
}
unittest {
import std.exception;
import thrift.transport.memory;
// Check the message header format.
auto buf = new TMemoryBuffer;
auto compact = tCompactProtocol(buf);
compact.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));
auto header = new ubyte[7];
buf.readAll(header);
enforce(header == [
130, // Protocol id.
33, // Version/type byte.
0, // Sequence id.
3, 102, 111, 111 // Method name.
]);
}
unittest {
import thrift.internal.test.protocol;
testContainerSizeLimit!(TCompactProtocol!())();
testStringSizeLimit!(TCompactProtocol!())();
}
/**
* TProtocolFactory creating a TCompactProtocol instance for passed in
* transports.
*
* The optional Transports template tuple parameter can be used to specify
* one or more TTransport implementations to specifically instantiate
* TCompactProtocol for. If the actual transport types encountered at
* runtime match one of the transports in the list, a specialized protocol
* instance is created. Otherwise, a generic TTransport version is used.
*/
class TCompactProtocolFactory(Transports...) if (
allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
///
this(int containerSizeLimit = 0, int stringSizeLimit = 0) {
containerSizeLimit_ = 0;
stringSizeLimit_ = 0;
}
TProtocol getProtocol(TTransport trans) const {
foreach (Transport; TypeTuple!(Transports, TTransport)) {
auto concreteTrans = cast(Transport)trans;
if (concreteTrans) {
return new TCompactProtocol!Transport(concreteTrans);
}
}
throw new TProtocolException(
"Passed null transport to TCompactProtocolFactory.");
}
int containerSizeLimit_;
int stringSizeLimit_;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.protocol.processor;
// Use selective import once DMD @@BUG314@@ is fixed.
import std.variant /+ : Variant +/;
import thrift.protocol.base;
import thrift.transport.base;
/**
* A processor is a generic object which operates upon an input stream and
* writes to some output stream.
*
* The definition of this object is loose, though the typical case is for some
* sort of server that either generates responses to an input stream or
* forwards data from one pipe onto another.
*
* An implementation can optionally allow one or more TProcessorEventHandlers
* to be attached, providing an interface to hook custom code into the
* handling process, which can be used e.g. for gathering statistics.
*/
interface TProcessor {
///
bool process(TProtocol iprot, TProtocol oprot,
Variant connectionContext = Variant()
) in {
assert(iprot);
assert(oprot);
}
///
final bool process(TProtocol prot, Variant connectionContext = Variant()) {
return process(prot, prot, connectionContext);
}
}
/**
* Handles events from a processor.
*/
interface TProcessorEventHandler {
/**
* Called before calling other callback methods.
*
* Expected to return some sort of »call context«, which is passed to all
* other callbacks for that function invocation.
*/
Variant createContext(string methodName, Variant connectionContext);
/**
* Called when handling the method associated with a context has been
* finished can be used to perform clean up work.
*/
void deleteContext(Variant callContext, string methodName);
/**
* Called before reading arguments.
*/
void preRead(Variant callContext, string methodName);
/**
* Called between reading arguments and calling the handler.
*/
void postRead(Variant callContext, string methodName);
/**
* Called between calling the handler and writing the response.
*/
void preWrite(Variant callContext, string methodName);
/**
* Called after writing the response.
*/
void postWrite(Variant callContext, string methodName);
/**
* Called when handling a one-way function call is completed successfully.
*/
void onewayComplete(Variant callContext, string methodName);
/**
* Called if the handler throws an undeclared exception.
*/
void handlerError(Variant callContext, string methodName, Exception e);
}
struct TConnectionInfo {
/// The input and output protocols.
TProtocol input;
TProtocol output; /// Ditto.
/// The underlying transport used for the connection
/// This is the transport that was returned by TServerTransport.accept(),
/// and it may be different than the transport pointed to by the input and
/// output protocols.
TTransport transport;
}
interface TProcessorFactory {
/**
* Get the TProcessor to use for a particular connection.
*
* This method is always invoked in the same thread that the connection was
* accepted on, which is always the same thread for all current server
* implementations.
*/
TProcessor getProcessor(ref const(TConnectionInfo) connInfo);
}
/**
* The default processor factory which always returns the same instance.
*/
class TSingletonProcessorFactory : TProcessorFactory {
/**
* Creates a new instance.
*
* Params:
* processor = The processor object to return from getProcessor().
*/
this(TProcessor processor) {
processor_ = processor;
}
override TProcessor getProcessor(ref const(TConnectionInfo) connInfo) {
return processor_;
}
private:
TProcessor processor_;
}

View file

@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.base;
import std.variant : Variant;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.protocol.processor;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* Base class for all Thrift servers.
*
* By setting the eventHandler property to a TServerEventHandler
* implementation, custom code can be integrated into the processing pipeline,
* which can be used e.g. for gathering statistics.
*/
class TServer {
/**
* Starts serving.
*
* Blocks until the server finishes, i.e. a serious problem occurred or the
* cancellation request has been triggered.
*
* Server implementations are expected to implement cancellation in a best-
* effort way usually, it should be possible to immediately stop accepting
* connections and return after all currently active clients have been
* processed, but this might not be the case for every conceivable
* implementation.
*/
abstract void serve(TCancellation cancellation = null);
/// The server event handler to notify. Null by default.
TServerEventHandler eventHandler;
protected:
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
this(processor, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory);
}
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
this(processorFactory, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory);
}
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
this(new TSingletonProcessorFactory(processor), serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory);
}
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
import std.exception;
import thrift.base;
enforce(inputTransportFactory,
new TException("Input transport factory must not be null."));
enforce(outputTransportFactory,
new TException("Output transport factory must not be null."));
enforce(inputProtocolFactory,
new TException("Input protocol factory must not be null."));
enforce(outputProtocolFactory,
new TException("Output protocol factory must not be null."));
processorFactory_ = processorFactory;
serverTransport_ = serverTransport;
inputTransportFactory_ = inputTransportFactory;
outputTransportFactory_ = outputTransportFactory;
inputProtocolFactory_ = inputProtocolFactory;
outputProtocolFactory_ = outputProtocolFactory;
}
TProcessorFactory processorFactory_;
TServerTransport serverTransport_;
TTransportFactory inputTransportFactory_;
TTransportFactory outputTransportFactory_;
TProtocolFactory inputProtocolFactory_;
TProtocolFactory outputProtocolFactory_;
}
/**
* Handles events from a TServer core.
*/
interface TServerEventHandler {
/**
* Called before the server starts accepting connections.
*/
void preServe();
/**
* Called when a new client has connected and processing is about to begin.
*/
Variant createContext(TProtocol input, TProtocol output);
/**
* Called when request handling for a client has been finished can be used
* to perform clean up work.
*/
void deleteContext(Variant serverContext, TProtocol input, TProtocol output);
/**
* Called when the processor for a client call is about to be invoked.
*/
void preProcess(Variant serverContext, TTransport transport);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.simple;
import std.variant : Variant;
import thrift.base;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* The most basic server.
*
* It is single-threaded and after it accepts a connections, it processes
* requests on it until it closes, then waiting for the next connection.
*
* It is not so much of use in production than it is for writing unittests, or
* as an example on how to provide a custom TServer implementation.
*/
class TSimpleServer : TServer {
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processor, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processorFactory, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processor, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processorFactory, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
override void serve(TCancellation cancellation = null) {
serverTransport_.listen();
if (eventHandler) eventHandler.preServe();
while (true) {
TTransport client;
TTransport inputTransport;
TTransport outputTransport;
TProtocol inputProtocol;
TProtocol outputProtocol;
try {
client = serverTransport_.accept(cancellation);
scope(failure) client.close();
inputTransport = inputTransportFactory_.getTransport(client);
scope(failure) inputTransport.close();
outputTransport = outputTransportFactory_.getTransport(client);
scope(failure) outputTransport.close();
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
} catch (TCancelledException tcx) {
break;
} catch (TTransportException ttx) {
logError("TServerTransport failed on accept: %s", ttx);
continue;
} catch (TException tx) {
logError("Caught TException on accept: %s", tx);
continue;
}
auto info = TConnectionInfo(inputProtocol, outputProtocol, client);
auto processor = processorFactory_.getProcessor(info);
Variant connectionContext;
if (eventHandler) {
connectionContext =
eventHandler.createContext(inputProtocol, outputProtocol);
}
try {
while (true) {
if (eventHandler) {
eventHandler.preProcess(connectionContext, client);
}
if (!processor.process(inputProtocol, outputProtocol,
connectionContext) || !inputProtocol.transport.peek()
) {
// Something went fundamentlly wrong or there is nothing more to
// process, close the connection.
break;
}
}
} catch (TTransportException ttx) {
logError("Client died: %s", ttx);
} catch (Exception e) {
logError("Uncaught exception: %s", e);
}
if (eventHandler) {
eventHandler.deleteContext(connectionContext, inputProtocol,
outputProtocol);
}
try {
inputTransport.close();
} catch (TTransportException ttx) {
logError("Input close failed: %s", ttx);
}
try {
outputTransport.close();
} catch (TTransportException ttx) {
logError("Output close failed: %s", ttx);
}
try {
client.close();
} catch (TTransportException ttx) {
logError("Client close failed: %s", ttx);
}
}
try {
serverTransport_.close();
} catch (TServerTransportException e) {
logError("Server transport failed to close(): %s", e);
}
}
}
unittest {
import thrift.internal.test.server;
testServeCancel!TSimpleServer();
}

View file

@ -0,0 +1,302 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.taskpool;
import core.sync.condition;
import core.sync.mutex;
import std.exception : enforce;
import std.parallelism;
import std.variant : Variant;
import thrift.base;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* A server which dispatches client requests to a std.parallelism TaskPool.
*/
class TTaskPoolServer : TServer {
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory,
TaskPool taskPool = null
) {
this(processor, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory, taskPool);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory,
TaskPool taskPool = null
) {
this(processorFactory, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory, taskPool);
}
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory,
TaskPool taskPool = null
) {
this(new TSingletonProcessorFactory(processor), serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory,
TaskPool taskPool = null
) {
super(processorFactory, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
if (taskPool) {
this.taskPool = taskPool;
} else {
auto ptp = std.parallelism.taskPool;
if (ptp.size > 0) {
taskPool_ = ptp;
} else {
// If the global task pool is empty (default on a single-core machine),
// create a new one with a single worker thread. The rationale for this
// is to avoid that an application which worked fine with no task pool
// explicitly set on the multi-core developer boxes suddenly fails on a
// single-core user machine.
taskPool_ = new TaskPool(1);
taskPool_.isDaemon = true;
}
}
}
override void serve(TCancellation cancellation = null) {
serverTransport_.listen();
if (eventHandler) eventHandler.preServe();
auto queueState = QueueState();
while (true) {
// Check if we can still handle more connections.
if (maxActiveConns) {
synchronized (queueState.mutex) {
while (queueState.activeConns >= maxActiveConns) {
queueState.connClosed.wait();
}
}
}
TTransport client;
TTransport inputTransport;
TTransport outputTransport;
TProtocol inputProtocol;
TProtocol outputProtocol;
try {
client = serverTransport_.accept(cancellation);
scope(failure) client.close();
inputTransport = inputTransportFactory_.getTransport(client);
scope(failure) inputTransport.close();
outputTransport = outputTransportFactory_.getTransport(client);
scope(failure) outputTransport.close();
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
} catch (TCancelledException tce) {
break;
} catch (TTransportException ttx) {
logError("TServerTransport failed on accept: %s", ttx);
continue;
} catch (TException tx) {
logError("Caught TException on accept: %s", tx);
continue;
}
auto info = TConnectionInfo(inputProtocol, outputProtocol, client);
auto processor = processorFactory_.getProcessor(info);
synchronized (queueState.mutex) {
++queueState.activeConns;
}
taskPool_.put(task!worker(queueState, client, inputProtocol,
outputProtocol, processor, eventHandler));
}
// First, stop accepting new connections.
try {
serverTransport_.close();
} catch (TServerTransportException e) {
logError("Server transport failed to close: %s", e);
}
// Then, wait until all active connections are finished.
synchronized (queueState.mutex) {
while (queueState.activeConns > 0) {
queueState.connClosed.wait();
}
}
}
/**
* Sets the task pool to use.
*
* By default, the global std.parallelism taskPool instance is used, which
* might not be appropriate for many applications, e.g. where tuning the
* number of worker threads is desired. (On single-core systems, a private
* task pool with a single thread is used by default, since the global
* taskPool instance has no worker threads then.)
*
* Note: TTaskPoolServer expects that tasks are never dropped from the pool,
* e.g. by calling TaskPool.close() while there are still tasks in the
* queue. If this happens, serve() will never return.
*/
void taskPool(TaskPool pool) @property {
enforce(pool !is null, "Cannot use a null task pool.");
enforce(pool.size > 0, "Cannot use a task pool with no worker threads.");
taskPool_ = pool;
}
/**
* The maximum number of client connections open at the same time. Zero for
* no limit, which is the default.
*
* If this limit is reached, no clients are accept()ed from the server
* transport any longer until another connection has been closed again.
*/
size_t maxActiveConns;
protected:
TaskPool taskPool_;
}
// Cannot be private as worker has to be passed as alias parameter to
// another module.
// private {
/*
* The state of the »connection queue«, i.e. used for keeping track of how
* many client connections are currently processed.
*/
struct QueueState {
/// Protects the queue state.
Mutex mutex;
/// The number of active connections (from the time they are accept()ed
/// until they are closed when the worked task finishes).
size_t activeConns;
/// Signals that the number of active connections has been decreased, i.e.
/// that a connection has been closed.
Condition connClosed;
/// Returns an initialized instance.
static QueueState opCall() {
QueueState q;
q.mutex = new Mutex;
q.connClosed = new Condition(q.mutex);
return q;
}
}
void worker(ref QueueState queueState, TTransport client,
TProtocol inputProtocol, TProtocol outputProtocol,
TProcessor processor, TServerEventHandler eventHandler)
{
scope (exit) {
synchronized (queueState.mutex) {
assert(queueState.activeConns > 0);
--queueState.activeConns;
queueState.connClosed.notifyAll();
}
}
Variant connectionContext;
if (eventHandler) {
connectionContext =
eventHandler.createContext(inputProtocol, outputProtocol);
}
try {
while (true) {
if (eventHandler) {
eventHandler.preProcess(connectionContext, client);
}
if (!processor.process(inputProtocol, outputProtocol,
connectionContext) || !inputProtocol.transport.peek()
) {
// Something went fundamentlly wrong or there is nothing more to
// process, close the connection.
break;
}
}
} catch (TTransportException ttx) {
logError("Client died: %s", ttx);
} catch (Exception e) {
logError("Uncaught exception: %s", e);
}
if (eventHandler) {
eventHandler.deleteContext(connectionContext, inputProtocol,
outputProtocol);
}
try {
inputProtocol.transport.close();
} catch (TTransportException ttx) {
logError("Input close failed: %s", ttx);
}
try {
outputProtocol.transport.close();
} catch (TTransportException ttx) {
logError("Output close failed: %s", ttx);
}
try {
client.close();
} catch (TTransportException ttx) {
logError("Client close failed: %s", ttx);
}
}
// }
unittest {
import thrift.internal.test.server;
testServeCancel!TTaskPoolServer();
}

View file

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.threaded;
import core.thread;
import std.variant : Variant;
import thrift.base;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* A simple threaded server which spawns a new thread per connection.
*/
class TThreadedServer : TServer {
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processor, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processorFactory, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processor, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processorFactory, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
override void serve(TCancellation cancellation = null) {
try {
// Start the server listening
serverTransport_.listen();
} catch (TTransportException ttx) {
logError("listen() failed: %s", ttx);
return;
}
if (eventHandler) eventHandler.preServe();
auto workerThreads = new ThreadGroup();
while (true) {
TTransport client;
TTransport inputTransport;
TTransport outputTransport;
TProtocol inputProtocol;
TProtocol outputProtocol;
try {
client = serverTransport_.accept(cancellation);
scope(failure) client.close();
inputTransport = inputTransportFactory_.getTransport(client);
scope(failure) inputTransport.close();
outputTransport = outputTransportFactory_.getTransport(client);
scope(failure) outputTransport.close();
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
} catch (TCancelledException tce) {
break;
} catch (TTransportException ttx) {
logError("TServerTransport failed on accept: %s", ttx);
continue;
} catch (TException tx) {
logError("Caught TException on accept: %s", tx);
continue;
}
auto info = TConnectionInfo(inputProtocol, outputProtocol, client);
auto processor = processorFactory_.getProcessor(info);
auto worker = new WorkerThread(client, inputProtocol, outputProtocol,
processor, eventHandler);
workerThreads.add(worker);
worker.start();
}
try {
serverTransport_.close();
} catch (TServerTransportException e) {
logError("Server transport failed to close: %s", e);
}
workerThreads.joinAll();
}
}
// The worker thread handling a client connection.
private class WorkerThread : Thread {
this(TTransport client, TProtocol inputProtocol, TProtocol outputProtocol,
TProcessor processor, TServerEventHandler eventHandler)
{
client_ = client;
inputProtocol_ = inputProtocol;
outputProtocol_ = outputProtocol;
processor_ = processor;
eventHandler_ = eventHandler;
super(&run);
}
void run() {
Variant connectionContext;
if (eventHandler_) {
connectionContext =
eventHandler_.createContext(inputProtocol_, outputProtocol_);
}
try {
while (true) {
if (eventHandler_) {
eventHandler_.preProcess(connectionContext, client_);
}
if (!processor_.process(inputProtocol_, outputProtocol_,
connectionContext) || !inputProtocol_.transport.peek()
) {
// Something went fundamentlly wrong or there is nothing more to
// process, close the connection.
break;
}
}
} catch (TTransportException ttx) {
logError("Client died: %s", ttx);
} catch (Exception e) {
logError("Uncaught exception: %s", e);
}
if (eventHandler_) {
eventHandler_.deleteContext(connectionContext, inputProtocol_,
outputProtocol_);
}
try {
inputProtocol_.transport.close();
} catch (TTransportException ttx) {
logError("Input close failed: %s", ttx);
}
try {
outputProtocol_.transport.close();
} catch (TTransportException ttx) {
logError("Output close failed: %s", ttx);
}
try {
client_.close();
} catch (TTransportException ttx) {
logError("Client close failed: %s", ttx);
}
}
private:
TTransport client_;
TProtocol inputProtocol_;
TProtocol outputProtocol_;
TProcessor processor_;
TServerEventHandler eventHandler_;
}
unittest {
import thrift.internal.test.server;
testServeCancel!TThreadedServer();
}

View file

@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.transport.base;
import thrift.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* Some kind of I/O device enabling servers to listen for incoming client
* connections and communicate with them via a TTransport interface.
*/
interface TServerTransport {
/**
* Starts listening for server connections.
*
* Just as simliar functions commonly found in socket libraries, this
* function does not block.
*
* If the socket is already listening, nothing happens.
*
* Throws: TServerTransportException if listening failed or the transport
* was already listening.
*/
void listen();
/**
* Closes the server transport, causing it to stop listening.
*
* Throws: TServerTransportException if the transport was not listening.
*/
void close();
/**
* Returns whether the server transport is currently listening.
*/
bool isListening() @property;
/**
* Accepts a client connection and returns an opened TTransport for it,
* never returning null.
*
* Blocks until a client connection is available.
*
* Params:
* cancellation = If triggered, requests the call to stop blocking and
* return with a TCancelledException. Implementations are free to
* ignore this if they cannot provide a reasonable.
*
* Throws: TServerTransportException if accepting failed,
* TCancelledException if it was cancelled.
*/
TTransport accept(TCancellation cancellation = null) out (result) {
assert(result !is null);
}
}
/**
* Server transport exception.
*/
class TServerTransportException : TException {
/**
* Error codes for the various types of exceptions.
*/
enum Type {
///
UNKNOWN,
/// The server socket is not listening, but excepted to be.
NOT_LISTENING,
/// The server socket is already listening, but expected not to be.
ALREADY_LISTENING,
/// An operation on the primary underlying resource, e.g. a socket used
/// for accepting connections, failed.
RESOURCE_FAILED
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
string msg = "TTransportException: ";
switch (type) {
case Type.UNKNOWN: msg ~= "Unknown server transport exception"; break;
case Type.NOT_LISTENING: msg ~= "Server transport not listening"; break;
case Type.ALREADY_LISTENING: msg ~= "Server transport already listening"; break;
case Type.RESOURCE_FAILED: msg ~= "An underlying resource failed"; break;
default: msg ~= "(Invalid exception type)"; break;
}
this(msg, type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() const nothrow @property {
return type_;
}
protected:
Type type_;
}

View file

@ -0,0 +1,380 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.transport.socket;
import core.thread : dur, Duration, Thread;
import core.stdc.string : strerror;
import std.array : empty;
import std.conv : text, to;
import std.exception : enforce;
import std.socket;
import thrift.base;
import thrift.internal.socket;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.transport.socket;
import thrift.util.awaitable;
import thrift.util.cancellation;
private alias TServerTransportException STE;
/**
* Server socket implementation of TServerTransport.
*
* Maps to std.socket listen()/accept(); only provides TCP/IP sockets (i.e. no
* Unix sockets) for now, because they are not supported in std.socket.
*/
class TServerSocket : TServerTransport {
/**
* Constructs a new instance.
*
* Params:
* port = The TCP port to listen at (host is always 0.0.0.0).
* sendTimeout = The socket sending timeout.
* recvTimout = The socket receiving timeout.
*/
this(ushort port, Duration sendTimeout = dur!"hnsecs"(0),
Duration recvTimeout = dur!"hnsecs"(0))
{
port_ = port;
sendTimeout_ = sendTimeout;
recvTimeout_ = recvTimeout;
cancellationNotifier_ = new TSocketNotifier;
socketSet_ = new SocketSet;
}
/// The port the server socket listens at.
ushort port() const @property {
return port_;
}
/// The socket sending timeout, zero to block infinitely.
void sendTimeout(Duration sendTimeout) @property {
sendTimeout_ = sendTimeout;
}
/// The socket receiving timeout, zero to block infinitely.
void recvTimeout(Duration recvTimeout) @property {
recvTimeout_ = recvTimeout;
}
/// The maximum number of listening retries if it fails.
void retryLimit(ushort retryLimit) @property {
retryLimit_ = retryLimit;
}
/// The delay between a listening attempt failing and retrying it.
void retryDelay(Duration retryDelay) @property {
retryDelay_ = retryDelay;
}
/// The size of the TCP send buffer, in bytes.
void tcpSendBuffer(int tcpSendBuffer) @property {
tcpSendBuffer_ = tcpSendBuffer;
}
/// The size of the TCP receiving buffer, in bytes.
void tcpRecvBuffer(int tcpRecvBuffer) @property {
tcpRecvBuffer_ = tcpRecvBuffer;
}
/// Whether to listen on IPv6 only, if IPv6 support is detected
/// (default: false).
void ipv6Only(bool value) @property {
ipv6Only_ = value;
}
override void listen() {
enforce(!isListening, new STE(STE.Type.ALREADY_LISTENING));
serverSocket_ = makeSocketAndListen(port_, ACCEPT_BACKLOG, retryLimit_,
retryDelay_, tcpSendBuffer_, tcpRecvBuffer_, ipv6Only_);
}
override void close() {
enforce(isListening, new STE(STE.Type.NOT_LISTENING));
serverSocket_.shutdown(SocketShutdown.BOTH);
serverSocket_.close();
serverSocket_ = null;
}
override bool isListening() @property {
return serverSocket_ !is null;
}
/// Number of connections listen() backlogs.
enum ACCEPT_BACKLOG = 1024;
override TTransport accept(TCancellation cancellation = null) {
enforce(isListening, new STE(STE.Type.NOT_LISTENING));
if (cancellation) cancellationNotifier_.attach(cancellation.triggering);
scope (exit) if (cancellation) cancellationNotifier_.detach();
// Too many EINTRs is a fault condition and would need to be handled
// manually by our caller, but we can tolerate a certain number.
enum MAX_EINTRS = 10;
uint numEintrs;
while (true) {
socketSet_.reset();
socketSet_.add(serverSocket_);
socketSet_.add(cancellationNotifier_.socket);
auto ret = Socket.select(socketSet_, null, null);
enforce(ret != 0, new STE("Socket.select() returned 0.",
STE.Type.RESOURCE_FAILED));
if (ret < 0) {
// Select itself failed, check if it was just due to an interrupted
// syscall.
if (getSocketErrno() == INTERRUPTED_ERRNO) {
if (numEintrs++ < MAX_EINTRS) {
continue;
} else {
throw new STE("Socket.select() was interrupted by a signal (EINTR) " ~
"more than " ~ to!string(MAX_EINTRS) ~ " times.",
STE.Type.RESOURCE_FAILED
);
}
}
throw new STE("Unknown error on Socket.select(): " ~
socketErrnoString(getSocketErrno()), STE.Type.RESOURCE_FAILED);
} else {
// Check for a ping on the interrupt socket.
if (socketSet_.isSet(cancellationNotifier_.socket)) {
cancellation.throwIfTriggered();
}
// Check for the actual server socket having a connection waiting.
if (socketSet_.isSet(serverSocket_)) {
break;
}
}
}
try {
auto client = createTSocket(serverSocket_.accept());
client.sendTimeout = sendTimeout_;
client.recvTimeout = recvTimeout_;
return client;
} catch (SocketException e) {
throw new STE("Unknown error on accepting: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
}
protected:
/**
* Allows derived classes to create a different TSocket type.
*/
TSocket createTSocket(Socket socket) {
return new TSocket(socket);
}
private:
ushort port_;
Duration sendTimeout_;
Duration recvTimeout_;
ushort retryLimit_;
Duration retryDelay_;
uint tcpSendBuffer_;
uint tcpRecvBuffer_;
bool ipv6Only_;
Socket serverSocket_;
TSocketNotifier cancellationNotifier_;
// Keep socket set between accept() calls to avoid reallocating.
SocketSet socketSet_;
}
Socket makeSocketAndListen(ushort port, int backlog, ushort retryLimit,
Duration retryDelay, uint tcpSendBuffer = 0, uint tcpRecvBuffer = 0,
bool ipv6Only = false
) {
Address localAddr;
try {
// null represents the wildcard address.
auto addrInfos = getAddressInfo(null, to!string(port),
AddressInfoFlags.PASSIVE, SocketType.STREAM, ProtocolType.TCP);
foreach (i, ai; addrInfos) {
// Prefer to bind to IPv6 addresses, because then IPv4 is listened to as
// well, but not the other way round.
if (ai.family == AddressFamily.INET6 || i == (addrInfos.length - 1)) {
localAddr = ai.address;
break;
}
}
} catch (Exception e) {
throw new STE("Could not determine local address to listen on.",
STE.Type.RESOURCE_FAILED, __FILE__, __LINE__, e);
}
Socket socket;
try {
socket = new Socket(localAddr.addressFamily, SocketType.STREAM,
ProtocolType.TCP);
} catch (SocketException e) {
throw new STE("Could not create accepting socket: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
try {
socket.setOption(SocketOptionLevel.IPV6, SocketOption.IPV6_V6ONLY, ipv6Only);
} catch (SocketException e) {
// This is somewhat expected on older systems (e.g. pre-Vista Windows),
// which do not support the IPV6_V6ONLY flag yet. Racy flag just to avoid
// log spew in unit tests.
shared static warned = false;
if (!warned) {
logError("Could not set IPV6_V6ONLY socket option: %s", e);
warned = true;
}
}
alias SocketOptionLevel.SOCKET lvlSock;
// Prevent 2 maximum segement lifetime delay on accept.
try {
socket.setOption(lvlSock, SocketOption.REUSEADDR, true);
} catch (SocketException e) {
throw new STE("Could not set REUSEADDR socket option: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
// Set TCP buffer sizes.
if (tcpSendBuffer > 0) {
try {
socket.setOption(lvlSock, SocketOption.SNDBUF, tcpSendBuffer);
} catch (SocketException e) {
throw new STE("Could not set socket send buffer size: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
}
if (tcpRecvBuffer > 0) {
try {
socket.setOption(lvlSock, SocketOption.RCVBUF, tcpRecvBuffer);
} catch (SocketException e) {
throw new STE("Could not set receive send buffer size: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
}
// Turn linger off to avoid blocking on socket close.
try {
Linger l;
l.on = 0;
l.time = 0;
socket.setOption(lvlSock, SocketOption.LINGER, l);
} catch (SocketException e) {
throw new STE("Could not disable socket linger: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
// Set TCP_NODELAY.
try {
socket.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
} catch (SocketException e) {
throw new STE("Could not disable Nagle's algorithm: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
ushort retries;
while (true) {
try {
socket.bind(localAddr);
break;
} catch (SocketException) {}
// If bind() worked, we breaked outside the loop above.
retries++;
if (retries < retryLimit) {
Thread.sleep(retryDelay);
} else {
throw new STE(text("Could not bind to address: ", localAddr),
STE.Type.RESOURCE_FAILED);
}
}
socket.listen(backlog);
return socket;
}
unittest {
// Test interrupt().
{
auto sock = new TServerSocket(0);
sock.listen();
scope (exit) sock.close();
auto cancellation = new TCancellationOrigin;
auto intThread = new Thread({
// Sleep for a bit until the socket is accepting.
Thread.sleep(dur!"msecs"(50));
cancellation.trigger();
});
intThread.start();
import std.exception;
assertThrown!TCancelledException(sock.accept(cancellation));
}
// Test receive() timeout on accepted client sockets.
{
immutable port = 11122;
auto timeout = dur!"msecs"(500);
auto serverSock = new TServerSocket(port, timeout, timeout);
serverSock.listen();
scope (exit) serverSock.close();
auto clientSock = new TSocket("127.0.0.1", port);
clientSock.open();
scope (exit) clientSock.close();
shared bool hasTimedOut;
auto recvThread = new Thread({
auto sock = serverSock.accept();
ubyte[1] data;
try {
sock.read(data);
} catch (TTransportException e) {
if (e.type == TTransportException.Type.TIMED_OUT) {
hasTimedOut = true;
} else {
import std.stdio;
stderr.writeln(e);
}
}
});
recvThread.isDaemon = true;
recvThread.start();
// Wait for the timeout, with a little bit of spare time.
Thread.sleep(timeout + dur!"msecs"(50));
enforce(hasTimedOut,
"Client socket receive() blocked for longer than recvTimeout.");
}
}

View file

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.transport.ssl;
import std.datetime : Duration;
import std.exception : enforce;
import std.socket : Socket;
import thrift.server.transport.socket;
import thrift.transport.base;
import thrift.transport.socket;
import thrift.transport.ssl;
/**
* A server transport implementation using SSL-encrypted sockets.
*
* Note:
* On Posix systems which do not have the BSD-specific SO_NOSIGPIPE flag, you
* might want to ignore the SIGPIPE signal, as OpenSSL might try to write to
* a closed socket if the peer disconnects abruptly:
* ---
* import core.stdc.signal;
* import core.sys.posix.signal;
* signal(SIGPIPE, SIG_IGN);
* ---
*
* See: thrift.transport.ssl.
*/
class TSSLServerSocket : TServerSocket {
/**
* Creates a new TSSLServerSocket.
*
* Params:
* port = The port on which to listen.
* sslContext = The TSSLContext to use for creating client
* sockets. Must be in server-side mode.
*/
this(ushort port, TSSLContext sslContext) {
super(port);
setSSLContext(sslContext);
}
/**
* Creates a new TSSLServerSocket.
*
* Params:
* port = The port on which to listen.
* sendTimeout = The send timeout to set on the client sockets.
* recvTimeout = The receive timeout to set on the client sockets.
* sslContext = The TSSLContext to use for creating client
* sockets. Must be in server-side mode.
*/
this(ushort port, Duration sendTimeout, Duration recvTimeout,
TSSLContext sslContext)
{
super(port, sendTimeout, recvTimeout);
setSSLContext(sslContext);
}
protected:
override TSocket createTSocket(Socket socket) {
return new TSSLSocket(sslContext_, socket);
}
private:
void setSSLContext(TSSLContext sslContext) {
enforce(sslContext.serverSide, new TTransportException(
"Need server-side SSL socket factory for TSSLServerSocket"));
sslContext_ = sslContext;
}
TSSLContext sslContext_;
}

View file

@ -0,0 +1,370 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.base;
import core.stdc.string : strerror;
import std.conv : text;
import thrift.base;
/**
* An entity data can be read from and/or written to.
*
* A TTransport implementation may capable of either reading or writing, but
* not necessarily both.
*/
interface TTransport {
/**
* Whether this transport is open.
*
* If a transport is closed, it can be opened by calling open(), and vice
* versa for close().
*
* While a transport should always be open when trying to read/write data,
* the related functions do not necessarily fail when called for a closed
* transport. Situations like this could occur e.g. with a wrapper
* transport which buffers data when the underlying transport has already
* been closed (possibly because the connection was abruptly closed), but
* there is still data left to be read in the buffers. This choice has been
* made to simplify transport implementations, in terms of both code
* complexity and runtime overhead.
*/
bool isOpen() @property;
/**
* Tests whether there is more data to read or if the remote side is
* still open.
*
* A typical use case would be a server checking if it should process
* another request on the transport.
*/
bool peek();
/**
* Opens the transport for communications.
*
* If the transport is already open, nothing happens.
*
* Throws: TTransportException if opening fails.
*/
void open();
/**
* Closes the transport.
*
* If the transport is not open, nothing happens.
*
* Throws: TTransportException if closing fails.
*/
void close();
/**
* Attempts to fill the given buffer by reading data.
*
* For potentially blocking data sources (e.g. sockets), read() will only
* block if no data is available at all. If there is some data available,
* but waiting for new data to arrive would be required to fill the whole
* buffer, the readily available data will be immediately returned use
* readAll() if you want to wait until the whole buffer is filled.
*
* Params:
* buf = Slice to use as buffer.
*
* Returns: How many bytes were actually read
*
* Throws: TTransportException if an error occurs.
*/
size_t read(ubyte[] buf);
/**
* Fills the given buffer by reading data into it, failing if not enough
* data is available.
*
* Params:
* buf = Slice to use as buffer.
*
* Throws: TTransportException if insufficient data is available or reading
* fails altogether.
*/
void readAll(ubyte[] buf);
/**
* Must be called by clients when read is completed.
*
* Implementations can choose to perform a transport-specific action, e.g.
* logging the request to a file.
*
* Returns: The number of bytes read if available, 0 otherwise.
*/
size_t readEnd();
/**
* Writes the passed slice of data.
*
* Note: You must call flush() to ensure the data is actually written,
* and available to be read back in the future. Destroying a TTransport
* object does not automatically flush pending data if you destroy a
* TTransport object with written but unflushed data, that data may be
* discarded.
*
* Params:
* buf = Slice of data to write.
*
* Throws: TTransportException if an error occurs.
*/
void write(in ubyte[] buf);
/**
* Must be called by clients when write is completed.
*
* Implementations can choose to perform a transport-specific action, e.g.
* logging the request to a file.
*
* Returns: The number of bytes written if available, 0 otherwise.
*/
size_t writeEnd();
/**
* Flushes any pending data to be written.
*
* Must be called before destruction to ensure writes are actually complete,
* otherwise pending data may be discarded. Typically used with buffered
* transport mechanisms.
*
* Throws: TTransportException if an error occurs.
*/
void flush();
/**
* Attempts to return a slice of <code>len</code> bytes of incoming data,
* possibly copied into buf, not consuming them (i.e.: a later read will
* return the same data).
*
* This method is meant to support protocols that need to read variable-
* length fields. They can attempt to borrow the maximum amount of data that
* they will need, then <code>consume()</code> what they actually use. Some
* transports will not support this method and others will fail occasionally,
* so protocols must be prepared to fall back to <code>read()</code> if
* borrow fails.
*
* The transport must be open when calling this.
*
* Params:
* buf = A buffer where the data can be stored if needed, or null to
* indicate that the caller is not supplying storage, but would like a
* slice of an internal buffer, if available.
* len = The number of bytes to borrow.
*
* Returns: If the borrow succeeds, a slice containing the borrowed data,
* null otherwise. The slice will be at least as long as requested, but
* may be longer if the returned slice points into an internal buffer
* rather than buf.
*
* Throws: TTransportException if an error occurs.
*/
const(ubyte)[] borrow(ubyte* buf, size_t len) out (result) {
// FIXME: Commented out because len gets corrupted in
// thrift.transport.memory borrow() unittest.
version(none) assert(result is null || result.length >= len,
"Buffer returned by borrow() too short.");
}
/**
* Remove len bytes from the transport. This must always follow a borrow
* of at least len bytes, and should always succeed.
*
* The transport must be open when calling this.
*
* Params:
* len = Number of bytes to consume.
*
* Throws: TTransportException if an error occurs.
*/
void consume(size_t len);
}
/**
* Provides basic fall-back implementations of the TTransport interface.
*/
class TBaseTransport : TTransport {
override bool isOpen() @property {
return false;
}
override bool peek() {
return isOpen;
}
override void open() {
throw new TTransportException("Cannot open TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override void close() {
throw new TTransportException("Cannot close TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override size_t read(ubyte[] buf) {
throw new TTransportException("Cannot read from a TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override void readAll(ubyte[] buf) {
size_t have;
while (have < buf.length) {
size_t get = read(buf[have..$]);
if (get <= 0) {
throw new TTransportException(text("Could not readAll() ", buf.length,
" bytes as no more data was available after ", have, " bytes."),
TTransportException.Type.END_OF_FILE);
}
have += get;
}
}
override size_t readEnd() {
// Do nothing by default, not needed by all implementations.
return 0;
}
override void write(in ubyte[] buf) {
throw new TTransportException("Cannot write to a TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override size_t writeEnd() {
// Do nothing by default, not needed by all implementations.
return 0;
}
override void flush() {
// Do nothing by default, not needed by all implementations.
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
// borrow() is allowed to fail anyway, so just return null.
return null;
}
override void consume(size_t len) {
throw new TTransportException("Cannot consume from a TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
protected:
this() {}
}
/**
* Makes a TTransport which wraps a given source transport in some way.
*
* A common use case is inside server implementations, where the raw client
* connections accepted from e.g. TServerSocket need to be wrapped into
* buffered or compressed transports.
*/
class TTransportFactory {
/**
* Default implementation does nothing, just returns the transport given.
*/
TTransport getTransport(TTransport trans) {
return trans;
}
}
/**
* Transport factory for transports which simply wrap an underlying TTransport
* without requiring additional configuration.
*/
class TWrapperTransportFactory(T) if (
is(T : TTransport) && __traits(compiles, new T(TTransport.init))
) : TTransportFactory {
override T getTransport(TTransport trans) {
return new T(trans);
}
}
/**
* Transport-level exception.
*/
class TTransportException : TException {
/**
* Error codes for the various types of exceptions.
*/
enum Type {
UNKNOWN, ///
NOT_OPEN, ///
TIMED_OUT, ///
END_OF_FILE, ///
INTERRUPTED, ///
BAD_ARGS, ///
CORRUPTED_DATA, ///
INTERNAL_ERROR, ///
NOT_IMPLEMENTED ///
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
static string msgForType(Type type) {
switch (type) {
case Type.UNKNOWN: return "Unknown transport exception";
case Type.NOT_OPEN: return "Transport not open";
case Type.TIMED_OUT: return "Timed out";
case Type.END_OF_FILE: return "End of file";
case Type.INTERRUPTED: return "Interrupted";
case Type.BAD_ARGS: return "Invalid arguments";
case Type.CORRUPTED_DATA: return "Corrupted Data";
case Type.INTERNAL_ERROR: return "Internal error";
case Type.NOT_IMPLEMENTED: return "Not implemented";
default: return "(Invalid exception type)";
}
}
this(msgForType(type), type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() const nothrow @property {
return type_;
}
protected:
Type type_;
}
/**
* Meta-programming helper returning whether the passed type is a TTransport
* implementation.
*/
template isTTransport(T) {
enum isTTransport = is(T : TTransport);
}

View file

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.buffered;
import std.algorithm : min;
import std.array : empty;
import std.exception : enforce;
import thrift.transport.base;
/**
* Wraps another transport and buffers reads and writes until the internal
* buffers are exhausted, at which point new data is fetched resp. the
* accumulated data is written out at once.
*/
final class TBufferedTransport : TBaseTransport {
/**
* Constructs a new instance, using the default buffer sizes.
*
* Params:
* transport = The underlying transport to wrap.
*/
this(TTransport transport) {
this(transport, DEFAULT_BUFFER_SIZE);
}
/**
* Constructs a new instance, using the specified buffer size.
*
* Params:
* transport = The underlying transport to wrap.
* bufferSize = The size of the read and write buffers to use, in bytes.
*/
this(TTransport transport, size_t bufferSize) {
this(transport, bufferSize, bufferSize);
}
/**
* Constructs a new instance, using the specified buffer size.
*
* Params:
* transport = The underlying transport to wrap.
* readBufferSize = The size of the read buffer to use, in bytes.
* writeBufferSize = The size of the write buffer to use, in bytes.
*/
this(TTransport transport, size_t readBufferSize, size_t writeBufferSize) {
transport_ = transport;
readBuffer_ = new ubyte[readBufferSize];
writeBuffer_ = new ubyte[writeBufferSize];
writeAvail_ = writeBuffer_;
}
/// The default size of the read/write buffers, in bytes.
enum int DEFAULT_BUFFER_SIZE = 512;
override bool isOpen() @property {
return transport_.isOpen();
}
override bool peek() {
if (readAvail_.empty) {
// If there is nothing available to read, see if we can get something
// from the underlying transport.
auto bytesRead = transport_.read(readBuffer_);
readAvail_ = readBuffer_[0 .. bytesRead];
}
return !readAvail_.empty;
}
override void open() {
transport_.open();
}
override void close() {
if (!isOpen) return;
flush();
transport_.close();
}
override size_t read(ubyte[] buf) {
if (readAvail_.empty) {
// No data left in our buffer, fetch some from the underlying transport.
if (buf.length > readBuffer_.length) {
// If the amount of data requested is larger than our reading buffer,
// directly read to the passed buffer. This probably doesn't occur too
// often in practice (and even if it does, the underlying transport
// probably cannot fulfill the request at once anyway), but it can't
// harm to try…
return transport_.read(buf);
}
auto bytesRead = transport_.read(readBuffer_);
readAvail_ = readBuffer_[0 .. bytesRead];
}
// Hand over whatever we have.
auto give = min(readAvail_.length, buf.length);
buf[0 .. give] = readAvail_[0 .. give];
readAvail_ = readAvail_[give .. $];
return give;
}
/**
* Shortcut version of readAll.
*/
override void readAll(ubyte[] buf) {
if (readAvail_.length >= buf.length) {
buf[] = readAvail_[0 .. buf.length];
readAvail_ = readAvail_[buf.length .. $];
return;
}
super.readAll(buf);
}
override void write(in ubyte[] buf) {
if (writeAvail_.length >= buf.length) {
// If the data fits in the buffer, just save it there.
writeAvail_[0 .. buf.length] = buf;
writeAvail_ = writeAvail_[buf.length .. $];
return;
}
// We have to decide if we copy data from buf to our internal buffer, or
// just directly write them out. The same considerations about avoiding
// syscalls as for C++ apply here.
auto bytesAvail = writeAvail_.ptr - writeBuffer_.ptr;
if ((bytesAvail + buf.length >= 2 * writeBuffer_.length) || (bytesAvail == 0)) {
// We would immediately need two syscalls anyway (or we don't have
// anything) in our buffer to write, so just write out both buffers.
if (bytesAvail > 0) {
transport_.write(writeBuffer_[0 .. bytesAvail]);
writeAvail_ = writeBuffer_;
}
transport_.write(buf);
return;
}
// Fill up our internal buffer for a write.
writeAvail_[] = buf[0 .. writeAvail_.length];
auto left = buf[writeAvail_.length .. $];
transport_.write(writeBuffer_);
// Copy the rest into our buffer.
writeBuffer_[0 .. left.length] = left[];
writeAvail_ = writeBuffer_[left.length .. $];
}
override void flush() {
// Write out any data waiting in the write buffer.
auto bytesAvail = writeAvail_.ptr - writeBuffer_.ptr;
if (bytesAvail > 0) {
// Note that we reset writeAvail_ prior to calling the underlying protocol
// to make sure the buffer is cleared even if the transport throws an
// exception.
writeAvail_ = writeBuffer_;
transport_.write(writeBuffer_[0 .. bytesAvail]);
}
// Flush the underlying transport.
transport_.flush();
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= readAvail_.length) {
return readAvail_;
}
return null;
}
override void consume(size_t len) {
enforce(len <= readBuffer_.length, new TTransportException(
"Invalid consume length.", TTransportException.Type.BAD_ARGS));
readAvail_ = readAvail_[len .. $];
}
/**
* The wrapped transport.
*/
TTransport underlyingTransport() @property {
return transport_;
}
private:
TTransport transport_;
ubyte[] readBuffer_;
ubyte[] writeBuffer_;
ubyte[] readAvail_;
ubyte[] writeAvail_;
}
/**
* Wraps given transports into TBufferedTransports.
*/
alias TWrapperTransportFactory!TBufferedTransport TBufferedTransportFactory;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,334 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.framed;
import core.bitop : bswap;
import std.algorithm : min;
import std.array : empty;
import std.exception : enforce;
import thrift.transport.base;
/**
* Framed transport.
*
* All writes go into an in-memory buffer until flush is called, at which point
* the transport writes the length of the entire binary chunk followed by the
* data payload. The receiver on the other end then performs a single
* »fixed-length« read to get the whole message off the wire.
*/
final class TFramedTransport : TBaseTransport {
/**
* Constructs a new framed transport.
*
* Params:
* transport = The underlying transport to wrap.
*/
this(TTransport transport) {
transport_ = transport;
}
/**
* Returns the wrapped transport.
*/
TTransport underlyingTransport() @property {
return transport_;
}
override bool isOpen() @property {
return transport_.isOpen;
}
override bool peek() {
return rBuf_.length > 0 || transport_.peek();
}
override void open() {
transport_.open();
}
override void close() {
flush();
transport_.close();
}
/**
* Attempts to read data into the given buffer, stopping when the buffer is
* exhausted or the frame end is reached.
*
* TODO: Contrary to the C++ implementation, this never does cross-frame
* reads is there actually a valid use case for that?
*
* Params:
* buf = Slice to use as buffer.
*
* Returns: How many bytes were actually read.
*
* Throws: TTransportException if an error occurs.
*/
override size_t read(ubyte[] buf) {
// If the buffer is empty, read a new frame off the wire.
if (rBuf_.empty) {
bool gotFrame = readFrame();
if (!gotFrame) return 0;
}
auto size = min(rBuf_.length, buf.length);
buf[0..size] = rBuf_[0..size];
rBuf_ = rBuf_[size..$];
return size;
}
override void write(in ubyte[] buf) {
wBuf_ ~= buf;
}
override void flush() {
if (wBuf_.empty) return;
// Properly reset the write buffer even some of the protocol operations go
// wrong.
scope (exit) {
wBuf_.length = 0;
wBuf_.assumeSafeAppend();
}
int len = bswap(cast(int)wBuf_.length);
transport_.write(cast(ubyte[])(&len)[0..1]);
transport_.write(wBuf_);
transport_.flush();
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= rBuf_.length) {
return rBuf_;
} else {
// Don't try attempting cross-frame borrows, trying that does not make
// much sense anyway.
return null;
}
}
override void consume(size_t len) {
enforce(len <= rBuf_.length, new TTransportException(
"Invalid consume length", TTransportException.Type.BAD_ARGS));
rBuf_ = rBuf_[len .. $];
}
private:
bool readFrame() {
// Read the size of the next frame. We can't use readAll() since that
// always throws an exception on EOF, but want to throw an exception only
// if EOF occurs after partial size data.
int size;
size_t size_read;
while (size_read < size.sizeof) {
auto data = (cast(ubyte*)&size)[size_read..size.sizeof];
auto read = transport_.read(data);
if (read == 0) {
if (size_read == 0) {
// EOF before any data was read.
return false;
} else {
// EOF after a partial frame header illegal.
throw new TTransportException(
"No more data to read after partial frame header",
TTransportException.Type.END_OF_FILE
);
}
}
size_read += read;
}
size = bswap(size);
enforce(size >= 0, new TTransportException("Frame size has negative value",
TTransportException.Type.CORRUPTED_DATA));
// TODO: Benchmark this.
rBuf_.length = size;
rBuf_.assumeSafeAppend();
transport_.readAll(rBuf_);
return true;
}
TTransport transport_;
ubyte[] rBuf_;
ubyte[] wBuf_;
}
/**
* Wraps given transports into TFramedTransports.
*/
alias TWrapperTransportFactory!TFramedTransport TFramedTransportFactory;
version (unittest) {
import std.random : Mt19937, uniform;
import thrift.transport.memory;
}
// Some basic random testing, always starting with the same seed for
// deterministic unit test results more tests in transport_test.
unittest {
auto randGen = Mt19937(42);
// 32 kiB of data to work with.
auto data = new ubyte[1 << 15];
foreach (ref b; data) {
b = uniform!"[]"(cast(ubyte)0, cast(ubyte)255, randGen);
}
// Generate a list of chunk sizes to split the data into. A uniform
// distribution is not quite realistic, but std.random doesn't have anything
// else yet.
enum MAX_FRAME_LENGTH = 512;
auto chunkSizesList = new size_t[][2];
foreach (ref chunkSizes; chunkSizesList) {
size_t sum;
while (true) {
auto curLen = uniform(0, MAX_FRAME_LENGTH, randGen);
sum += curLen;
if (sum > data.length) break;
chunkSizes ~= curLen;
}
}
chunkSizesList ~= [data.length]; // Also test whole chunk at once.
// Test writing data.
{
foreach (chunkSizes; chunkSizesList) {
auto buf = new TMemoryBuffer;
auto framed = new TFramedTransport(buf);
auto remainingData = data;
foreach (chunkSize; chunkSizes) {
framed.write(remainingData[0..chunkSize]);
remainingData = remainingData[chunkSize..$];
}
framed.flush();
auto writtenData = data[0..($ - remainingData.length)];
auto actualData = buf.getContents();
// Check frame size.
int frameSize = bswap((cast(int[])(actualData[0..int.sizeof]))[0]);
enforce(frameSize == writtenData.length);
// Check actual data.
enforce(actualData[int.sizeof..$] == writtenData);
}
}
// Test reading data.
{
foreach (chunkSizes; chunkSizesList) {
auto buf = new TMemoryBuffer;
auto size = bswap(cast(int)data.length);
buf.write(cast(ubyte[])(&size)[0..1]);
buf.write(data);
auto framed = new TFramedTransport(buf);
ubyte[] readData;
readData.reserve(data.length);
foreach (chunkSize; chunkSizes) {
// This should work with read because we have one huge frame.
auto oldReadLen = readData.length;
readData.length += chunkSize;
framed.read(readData[oldReadLen..$]);
}
enforce(readData == data[0..readData.length]);
}
}
// Test combined reading/writing of multiple frames.
foreach (flushProbability; [1, 2, 4, 8, 16, 32]) {
foreach (chunkSizes; chunkSizesList) {
auto buf = new TMemoryBuffer;
auto framed = new TFramedTransport(buf);
size_t[] frameSizes;
// Write the data.
size_t frameSize;
auto remainingData = data;
foreach (chunkSize; chunkSizes) {
framed.write(remainingData[0..chunkSize]);
remainingData = remainingData[chunkSize..$];
frameSize += chunkSize;
if (frameSize > 0 && uniform(0, flushProbability, randGen) == 0) {
frameSizes ~= frameSize;
frameSize = 0;
framed.flush();
}
}
if (frameSize > 0) {
frameSizes ~= frameSize;
frameSize = 0;
framed.flush();
}
// Read it back.
auto readData = new ubyte[data.length - remainingData.length];
auto remainToRead = readData;
foreach (fSize; frameSizes) {
// We are exploiting an implementation detail of TFramedTransport:
// The read buffer starts empty and it will never return more than one
// frame per read, so by just requesting all of the data, we should
// always get exactly one frame.
auto got = framed.read(remainToRead);
enforce(got == fSize);
remainToRead = remainToRead[fSize..$];
}
enforce(remainToRead.empty);
enforce(readData == data[0..readData.length]);
}
}
}
// Test flush()ing an empty buffer.
unittest {
auto buf = new TMemoryBuffer();
auto framed = new TFramedTransport(buf);
immutable out1 = [0, 0, 0, 1, 'a'];
immutable out2 = [0, 0, 0, 1, 'a', 0, 0, 0, 2, 'b', 'c'];
framed.flush();
enforce(buf.getContents() == []);
framed.flush();
framed.flush();
enforce(buf.getContents() == []);
framed.write(cast(ubyte[])"a");
enforce(buf.getContents() == []);
framed.flush();
enforce(buf.getContents() == out1);
framed.flush();
framed.flush();
enforce(buf.getContents() == out1);
framed.write(cast(ubyte[])"bc");
enforce(buf.getContents() == out1);
framed.flush();
enforce(buf.getContents() == out2);
framed.flush();
framed.flush();
enforce(buf.getContents() == out2);
}

View file

@ -0,0 +1,459 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* HTTP tranpsort implementation, modelled after the C++ one.
*
* Unfortunately, libcurl is quite heavyweight and supports only client-side
* applications. This is an implementation of the basic HTTP/1.1 parts
* supporting HTTP 100 Continue, chunked transfer encoding, keepalive, etc.
*/
module thrift.transport.http;
import std.algorithm : canFind, countUntil, endsWith, findSplit, min, startsWith;
import std.ascii : toLower;
import std.array : empty;
import std.conv : parse, to;
import std.datetime : Clock, UTC;
import std.string : stripLeft;
import thrift.base : VERSION;
import thrift.transport.base;
import thrift.transport.memory;
import thrift.transport.socket;
/**
* Base class for both client- and server-side HTTP transports.
*/
abstract class THttpTransport : TBaseTransport {
this(TTransport transport) {
transport_ = transport;
readHeaders_ = true;
httpBuf_ = new ubyte[HTTP_BUFFER_SIZE];
httpBufRemaining_ = httpBuf_[0 .. 0];
readBuffer_ = new TMemoryBuffer;
writeBuffer_ = new TMemoryBuffer;
}
override bool isOpen() {
return transport_.isOpen();
}
override bool peek() {
return transport_.peek();
}
override void open() {
transport_.open();
}
override void close() {
transport_.close();
}
override size_t read(ubyte[] buf) {
if (!readBuffer_.peek()) {
readBuffer_.reset();
if (!refill()) return 0;
if (readHeaders_) {
readHeaders();
}
size_t got;
if (chunked_) {
got = readChunked();
} else {
got = readContent(contentLength_);
}
readHeaders_ = true;
if (got == 0) return 0;
}
return readBuffer_.read(buf);
}
override size_t readEnd() {
// Read any pending chunked data (footers etc.)
if (chunked_) {
while (!chunkedDone_) {
readChunked();
}
}
return 0;
}
override void write(in ubyte[] buf) {
writeBuffer_.write(buf);
}
override void flush() {
auto data = writeBuffer_.getContents();
string header = getHeader(data.length);
transport_.write(cast(const(ubyte)[]) header);
transport_.write(data);
transport_.flush();
// Reset the buffer and header variables.
writeBuffer_.reset();
readHeaders_ = true;
}
/**
* The size of the buffer to read HTTP requests into, in bytes. Will expand
* as required.
*/
enum HTTP_BUFFER_SIZE = 1024;
protected:
abstract string getHeader(size_t dataLength);
abstract bool parseStatusLine(const(ubyte)[] status);
void parseHeader(const(ubyte)[] header) {
auto split = findSplit(header, [':']);
if (split[1].empty) {
// No colon found.
return;
}
static bool compToLower(ubyte a, ubyte b) {
return toLower(cast(char)a) == toLower(cast(char)b);
}
if (startsWith!compToLower(split[0], cast(ubyte[])"transfer-encoding")) {
if (endsWith!compToLower(split[2], cast(ubyte[])"chunked")) {
chunked_ = true;
}
} else if (startsWith!compToLower(split[0], cast(ubyte[])"content-length")) {
chunked_ = false;
auto lengthString = stripLeft(cast(const(char)[])split[2]);
contentLength_ = parse!size_t(lengthString);
}
}
private:
ubyte[] readLine() {
while (true) {
auto split = findSplit(httpBufRemaining_, cast(ubyte[])"\r\n");
if (split[1].empty) {
// No CRLF yet, move whatever we have now to front and refill.
if (httpBufRemaining_.empty) {
httpBufRemaining_ = httpBuf_[0 .. 0];
} else {
httpBuf_[0 .. httpBufRemaining_.length] = httpBufRemaining_;
httpBufRemaining_ = httpBuf_[0 .. httpBufRemaining_.length];
}
if (!refill()) {
auto buf = httpBufRemaining_;
httpBufRemaining_ = httpBufRemaining_[$ - 1 .. $ - 1];
return buf;
}
} else {
// Set the remaining buffer to the part after \r\n and return the part
// (line) before it.
httpBufRemaining_ = split[2];
return split[0];
}
}
}
void readHeaders() {
// Initialize headers state variables
contentLength_ = 0;
chunked_ = false;
chunkedDone_ = false;
chunkSize_ = 0;
// Control state flow
bool statusLine = true;
bool finished;
// Loop until headers are finished
while (true) {
auto line = readLine();
if (line.length == 0) {
if (finished) {
readHeaders_ = false;
return;
} else {
// Must have been an HTTP 100, keep going for another status line
statusLine = true;
}
} else {
if (statusLine) {
statusLine = false;
finished = parseStatusLine(line);
} else {
parseHeader(line);
}
}
}
}
size_t readChunked() {
size_t length;
auto line = readLine();
size_t chunkSize;
try {
auto charLine = cast(char[])line;
chunkSize = parse!size_t(charLine, 16);
} catch (Exception e) {
throw new TTransportException("Invalid chunk size: " ~ to!string(line),
TTransportException.Type.CORRUPTED_DATA);
}
if (chunkSize == 0) {
readChunkedFooters();
} else {
// Read data content
length += readContent(chunkSize);
// Read trailing CRLF after content
readLine();
}
return length;
}
void readChunkedFooters() {
while (true) {
auto line = readLine();
if (line.length == 0) {
chunkedDone_ = true;
break;
}
}
}
size_t readContent(size_t size) {
auto need = size;
while (need > 0) {
if (httpBufRemaining_.length == 0) {
// We have given all the data, reset position to head of the buffer.
httpBufRemaining_ = httpBuf_[0 .. 0];
if (!refill()) return size - need;
}
auto give = min(httpBufRemaining_.length, need);
readBuffer_.write(cast(ubyte[])httpBufRemaining_[0 .. give]);
httpBufRemaining_ = httpBufRemaining_[give .. $];
need -= give;
}
return size;
}
bool refill() {
// Is there a nicer way to do this?
auto indexBegin = httpBufRemaining_.ptr - httpBuf_.ptr;
auto indexEnd = indexBegin + httpBufRemaining_.length;
if (httpBuf_.length - indexEnd <= (httpBuf_.length / 4)) {
httpBuf_.length *= 2;
}
// Read more data.
auto got = transport_.read(cast(ubyte[])httpBuf_[indexEnd .. $]);
if (got == 0) return false;
httpBufRemaining_ = httpBuf_[indexBegin .. indexEnd + got];
return true;
}
TTransport transport_;
TMemoryBuffer writeBuffer_;
TMemoryBuffer readBuffer_;
bool readHeaders_;
bool chunked_;
bool chunkedDone_;
size_t chunkSize_;
size_t contentLength_;
ubyte[] httpBuf_;
ubyte[] httpBufRemaining_;
}
/**
* HTTP client transport.
*/
final class TClientHttpTransport : THttpTransport {
/**
* Constructs a client http transport operating on the passed underlying
* transport.
*
* Params:
* transport = The underlying transport used for the actual I/O.
* host = The HTTP host string.
* path = The HTTP path string.
*/
this(TTransport transport, string host, string path) {
super(transport);
host_ = host;
path_ = path;
}
/**
* Convenience overload for constructing a client HTTP transport using a
* TSocket connecting to the specified host and port.
*
* Params:
* host = The server to connect to, also used as HTTP host string.
* port = The port to connect to.
* path = The HTTP path string.
*/
this(string host, ushort port, string path) {
this(new TSocket(host, port), host, path);
}
protected:
override string getHeader(size_t dataLength) {
return "POST " ~ path_ ~ " HTTP/1.1\r\n" ~
"Host: " ~ host_ ~ "\r\n" ~
"Content-Type: application/x-thrift\r\n" ~
"Content-Length: " ~ to!string(dataLength) ~ "\r\n" ~
"Accept: application/x-thrift\r\n"
"User-Agent: Thrift/" ~ VERSION ~ " (D/TClientHttpTransport)\r\n" ~
"\r\n";
}
override bool parseStatusLine(const(ubyte)[] status) {
// HTTP-Version SP Status-Code SP Reason-Phrase CRLF
auto firstSplit = findSplit(status, [' ']);
if (firstSplit[1].empty) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
auto codeReason = firstSplit[2][countUntil!"a != b"(firstSplit[2], ' ') .. $];
auto secondSplit = findSplit(codeReason, [' ']);
if (secondSplit[1].empty) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
if (secondSplit[0] == "200") {
// HTTP 200 = OK, we got the response
return true;
} else if (secondSplit[0] == "100") {
// HTTP 100 = continue, just keep reading
return false;
}
throw new TTransportException("Bad status (unhandled status code): " ~
to!string(cast(const(char[]))status), TTransportException.Type.CORRUPTED_DATA);
}
private:
string host_;
string path_;
}
/**
* HTTP server transport.
*/
final class TServerHttpTransport : THttpTransport {
/**
* Constructs a new instance.
*
* Param:
* transport = The underlying transport used for the actual I/O.
*/
this(TTransport transport) {
super(transport);
}
protected:
override string getHeader(size_t dataLength) {
return "HTTP/1.1 200 OK\r\n" ~
"Date: " ~ getRFC1123Time() ~ "\r\n" ~
"Server: Thrift/" ~ VERSION ~ "\r\n" ~
"Content-Type: application/x-thrift\r\n" ~
"Content-Length: " ~ to!string(dataLength) ~ "\r\n" ~
"Connection: Keep-Alive\r\n" ~
"\r\n";
}
override bool parseStatusLine(const(ubyte)[] status) {
// Method SP Request-URI SP HTTP-Version CRLF.
auto split = findSplit(status, [' ']);
if (split[1].empty) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
auto uriVersion = split[2][countUntil!"a != b"(split[2], ' ') .. $];
if (!canFind(uriVersion, ' ')) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
if (split[0] == "POST") {
// POST method ok, looking for content.
return true;
}
throw new TTransportException("Bad status (unsupported method): " ~
to!string(status), TTransportException.Type.CORRUPTED_DATA);
}
}
/**
* Wraps a transport into a HTTP server protocol.
*/
alias TWrapperTransportFactory!TServerHttpTransport TServerHttpTransportFactory;
private {
import std.string : format;
string getRFC1123Time() {
auto sysTime = Clock.currTime(UTC());
auto dayName = capMemberName(sysTime.dayOfWeek);
auto monthName = capMemberName(sysTime.month);
return format("%s, %s %s %s %s:%s:%s GMT", dayName, sysTime.day,
monthName, sysTime.year, sysTime.hour, sysTime.minute, sysTime.second);
}
import std.ascii : toUpper;
import std.traits : EnumMembers;
string capMemberName(T)(T val) if (is(T == enum)) {
foreach (i, e; EnumMembers!T) {
enum name = __traits(derivedMembers, T)[i];
enum capName = cast(char) toUpper(name[0]) ~ name [1 .. $];
if (val == e) {
return capName;
}
}
throw new Exception("Not a member of " ~ T.stringof ~ ": " ~ to!string(val));
}
unittest {
enum Foo {
bar,
bAZ
}
import std.exception;
enforce(capMemberName(Foo.bar) == "Bar");
enforce(capMemberName(Foo.bAZ) == "BAZ");
}
}

View file

@ -0,0 +1,233 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.memory;
import core.exception : onOutOfMemoryError;
import core.stdc.stdlib : free, realloc;
import std.algorithm : min;
import std.conv : text;
import thrift.transport.base;
/**
* A transport that simply reads from and writes to an in-memory buffer. Every
* time you call write on it, the data is simply placed into a buffer, and
* every time you call read, data is consumed from that buffer.
*
* Currently, the storage for written data is never reclaimed, even if the
* buffer contents have already been read out again.
*/
final class TMemoryBuffer : TBaseTransport {
/**
* Constructs a new memory transport with an empty internal buffer.
*/
this() {}
/**
* Constructs a new memory transport with an empty internal buffer,
* reserving space for capacity bytes in advance.
*
* If the amount of data which will be written to the buffer is already
* known on construction, this can better performance over the default
* constructor because reallocations can be avoided.
*
* If the preallocated buffer is exhausted, data can still be written to the
* transport, but reallocations will happen.
*
* Params:
* capacity = Size of the initially reserved buffer (in bytes).
*/
this(size_t capacity) {
reset(capacity);
}
/**
* Constructs a new memory transport initially containing the passed data.
*
* For now, the passed buffer is not intelligently used, the data is just
* copied to the internal buffer.
*
* Params:
* buffer = Initial contents available to be read.
*/
this(in ubyte[] contents) {
auto size = contents.length;
reset(size);
buffer_[0 .. size] = contents[];
writeOffset_ = size;
}
/**
* Destructor, frees the internally allocated buffer.
*/
~this() {
free(buffer_);
}
/**
* Returns a read-only view of the current buffer contents.
*
* Note: For performance reasons, the returned slice is only valid for the
* life of this object, and may be invalidated on the next write() call at
* will you might want to immediately .dup it if you intend to keep it
* around.
*/
const(ubyte)[] getContents() {
return buffer_[readOffset_ .. writeOffset_];
}
/**
* A memory transport is always open.
*/
override bool isOpen() @property {
return true;
}
override bool peek() {
return writeOffset_ - readOffset_ > 0;
}
/**
* Opening is a no-op() for a memory buffer.
*/
override void open() {}
/**
* Closing is a no-op() for a memory buffer, it is always open.
*/
override void close() {}
override size_t read(ubyte[] buf) {
auto size = min(buf.length, writeOffset_ - readOffset_);
buf[0 .. size] = buffer_[readOffset_ .. readOffset_ + size];
readOffset_ += size;
return size;
}
/**
* Shortcut version of readAll() using this over TBaseTransport.readAll()
* can give us a nice speed increase because gives us a nice speed increase
* because it is typically a very hot path during deserialization.
*/
override void readAll(ubyte[] buf) {
auto available = writeOffset_ - readOffset_;
if (buf.length > available) {
throw new TTransportException(text("Cannot readAll() ", buf.length,
" bytes of data because only ", available, " bytes are available."),
TTransportException.Type.END_OF_FILE);
}
buf[] = buffer_[readOffset_ .. readOffset_ + buf.length];
readOffset_ += buf.length;
}
override void write(in ubyte[] buf) {
auto need = buf.length;
if (bufferLen_ - writeOffset_ < need) {
// Exponential growth.
auto newLen = bufferLen_ + 1;
while (newLen - writeOffset_ < need) newLen *= 2;
cRealloc(buffer_, newLen);
bufferLen_ = newLen;
}
buffer_[writeOffset_ .. writeOffset_ + need] = buf[];
writeOffset_ += need;
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= writeOffset_ - readOffset_) {
return buffer_[readOffset_ .. writeOffset_];
} else {
return null;
}
}
override void consume(size_t len) {
readOffset_ += len;
}
void reset() {
readOffset_ = 0;
writeOffset_ = 0;
}
void reset(size_t capacity) {
readOffset_ = 0;
writeOffset_ = 0;
if (bufferLen_ < capacity) {
cRealloc(buffer_, capacity);
bufferLen_ = capacity;
}
}
private:
ubyte* buffer_;
size_t bufferLen_;
size_t readOffset_;
size_t writeOffset_;
}
private {
void cRealloc(ref ubyte* data, size_t newSize) {
auto result = realloc(data, newSize);
if (result is null) onOutOfMemoryError();
data = cast(ubyte*)result;
}
}
version (unittest) {
import std.exception;
}
unittest {
auto a = new TMemoryBuffer(5);
immutable(ubyte[]) testData = [1, 2, 3, 4];
auto buf = new ubyte[testData.length];
enforce(a.isOpen);
// a should be empty.
enforce(!a.peek());
enforce(a.read(buf) == 0);
assertThrown!TTransportException(a.readAll(buf));
// Write some data and read it back again.
a.write(testData);
enforce(a.peek());
enforce(a.getContents() == testData);
enforce(a.read(buf) == testData.length);
enforce(buf == testData);
// a should be empty again.
enforce(!a.peek());
enforce(a.read(buf) == 0);
assertThrown!TTransportException(a.readAll(buf));
// Test the constructor which directly accepts initial data.
auto b = new TMemoryBuffer(testData);
enforce(b.isOpen);
enforce(b.peek());
enforce(b.getContents() == testData);
// Test borrow().
auto borrowed = b.borrow(null, testData.length);
enforce(borrowed == testData);
enforce(b.peek());
b.consume(testData.length);
enforce(!b.peek());
}

View file

@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.piped;
import thrift.transport.base;
import thrift.transport.memory;
/**
* Pipes data request from one transport to another when readEnd()
* or writeEnd() is called.
*
* A typical use case would be to log requests on e.g. a socket to
* disk (i. e. pipe them to a TFileWriterTransport).
*
* The implementation keeps an internal buffer which expands to
* hold the whole amount of data read/written until the corresponding *End()
* method is called.
*
* Contrary to the C++ implementation, this doesn't introduce yet another layer
* of input/output buffering, all calls are passed to the underlying source
* transport verbatim.
*/
final class TPipedTransport(Source = TTransport) if (
isTTransport!Source
) : TBaseTransport {
/// The default initial buffer size if not explicitly specified, in bytes.
enum DEFAULT_INITIAL_BUFFER_SIZE = 512;
/**
* Constructs a new instance.
*
* By default, only reads are piped (pipeReads = true, pipeWrites = false).
*
* Params:
* srcTrans = The transport to which all requests are forwarded.
* dstTrans = The transport the read/written data is copied to.
* initialBufferSize = The default size of the read/write buffers, for
* performance tuning.
*/
this(Source srcTrans, TTransport dstTrans,
size_t initialBufferSize = DEFAULT_INITIAL_BUFFER_SIZE
) {
srcTrans_ = srcTrans;
dstTrans_ = dstTrans;
readBuffer_ = new TMemoryBuffer(initialBufferSize);
writeBuffer_ = new TMemoryBuffer(initialBufferSize);
pipeReads_ = true;
pipeWrites_ = false;
}
bool pipeReads() @property const {
return pipeReads_;
}
void pipeReads(bool value) @property {
if (!value) {
readBuffer_.reset();
}
pipeReads_ = value;
}
bool pipeWrites() @property const {
return pipeWrites_;
}
void pipeWrites(bool value) @property {
if (!value) {
writeBuffer_.reset();
}
pipeWrites_ = value;
}
override bool isOpen() {
return srcTrans_.isOpen();
}
override bool peek() {
return srcTrans_.peek();
}
override void open() {
srcTrans_.open();
}
override void close() {
srcTrans_.close();
}
override size_t read(ubyte[] buf) {
auto bytesRead = srcTrans_.read(buf);
if (pipeReads_) {
readBuffer_.write(buf[0 .. bytesRead]);
}
return bytesRead;
}
override size_t readEnd() {
if (pipeReads_) {
auto data = readBuffer_.getContents();
dstTrans_.write(data);
dstTrans_.flush();
readBuffer_.reset();
srcTrans_.readEnd();
// Return data.length instead of the readEnd() result of the source
// transports because it might not be available from it.
return data.length;
}
return srcTrans_.readEnd();
}
override void write(in ubyte[] buf) {
if (pipeWrites_) {
writeBuffer_.write(buf);
}
srcTrans_.write(buf);
}
override size_t writeEnd() {
if (pipeWrites_) {
auto data = writeBuffer_.getContents();
dstTrans_.write(data);
dstTrans_.flush();
writeBuffer_.reset();
srcTrans_.writeEnd();
// Return data.length instead of the readEnd() result of the source
// transports because it might not be available from it.
return data.length;
}
return srcTrans_.writeEnd();
}
override void flush() {
srcTrans_.flush();
}
private:
Source srcTrans_;
TTransport dstTrans_;
TMemoryBuffer readBuffer_;
TMemoryBuffer writeBuffer_;
bool pipeReads_;
bool pipeWrites_;
}
/**
* TPipedTransport construction helper to avoid having to explicitly
* specify the transport types, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TPipedTransport!Source tPipedTransport(Source)(
Source srcTrans, TTransport dstTrans
) if (isTTransport!Source) {
return new typeof(return)(srcTrans, dstTrans);
}
version (unittest) {
// DMD @@BUG@@: UFCS for std.array.empty doesn't work when import is moved
// into unittest block.
import std.array;
import std.exception : enforce;
}
unittest {
auto underlying = new TMemoryBuffer;
auto pipeTarget = new TMemoryBuffer;
auto trans = tPipedTransport(underlying, pipeTarget);
underlying.write(cast(ubyte[])"abcd");
ubyte[4] buffer;
trans.readAll(buffer[0 .. 2]);
enforce(buffer[0 .. 2] == "ab");
enforce(pipeTarget.getContents().empty);
trans.readEnd();
enforce(pipeTarget.getContents() == "ab");
pipeTarget.reset();
underlying.write(cast(ubyte[])"ef");
trans.readAll(buffer[0 .. 2]);
enforce(buffer[0 .. 2] == "cd");
enforce(pipeTarget.getContents().empty);
trans.readAll(buffer[0 .. 2]);
enforce(buffer[0 .. 2] == "ef");
enforce(pipeTarget.getContents().empty);
trans.readEnd();
enforce(pipeTarget.getContents() == "cdef");
}

View file

@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Transports which operate on generic D ranges.
*/
module thrift.transport.range;
import std.array : empty;
import std.range;
import std.traits : Unqual;
import thrift.transport.base;
/**
* Adapts an ubyte input range for reading via the TTransport interface.
*
* The case where R is a plain ubyte[] is reasonably optimized, so a possible
* use case for TInputRangeTransport would be to deserialize some data held in
* a memory buffer.
*/
final class TInputRangeTransport(R) if (
isInputRange!(Unqual!R) && is(ElementType!R : const(ubyte))
) : TBaseTransport {
/**
* Constructs a new instance.
*
* Params:
* data = The input range to use as data.
*/
this(R data) {
data_ = data;
}
/**
* An input range transport is always open.
*/
override bool isOpen() @property {
return true;
}
override bool peek() {
return !data_.empty;
}
/**
* Opening is a no-op() for an input range transport.
*/
override void open() {}
/**
* Closing is a no-op() for a memory buffer.
*/
override void close() {}
override size_t read(ubyte[] buf) {
auto data = data_.take(buf.length);
auto bytes = data.length;
static if (is(typeof(R.init[1 .. 2]) : const(ubyte)[])) {
// put() is currently unnecessarily slow if both ranges are sliceable.
buf[0 .. bytes] = data[];
data_ = data_[bytes .. $];
} else {
buf.put(data);
}
return bytes;
}
/**
* Shortcut version of readAll() for slicable ranges.
*
* Because readAll() is typically a very hot path during deserialization,
* using this over TBaseTransport.readAll() gives us a nice increase in
* speed due to the reduced amount of indirections.
*/
override void readAll(ubyte[] buf) {
static if (is(typeof(R.init[1 .. 2]) : const(ubyte)[])) {
if (buf.length <= data_.length) {
buf[] = data_[0 .. buf.length];
data_ = data_[buf.length .. $];
return;
}
}
super.readAll(buf);
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
static if (is(R : const(ubyte)[])) {
// Can only borrow if our data type is actually an ubyte array.
if (len <= data_.length) {
return data_;
}
}
return null;
}
override void consume(size_t len) {
static if (is(R : const(ubyte)[])) {
if (len > data_.length) {
throw new TTransportException("Invalid consume length",
TTransportException.Type.BAD_ARGS);
}
data_ = data_[len .. $];
} else {
super.consume(len);
}
}
/**
* Sets a new data range to use.
*/
void reset(R data) {
data_ = data;
}
private:
R data_;
}
/**
* TInputRangeTransport construction helper to avoid having to explicitly
* specify the argument type, i.e. to allow the constructor being called using
* IFTI (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D
* Bugzilla enhancement requet 6082)).
*/
TInputRangeTransport!R tInputRangeTransport(R)(R data) if (
is (TInputRangeTransport!R)
) {
return new TInputRangeTransport!R(data);
}

View file

@ -0,0 +1,453 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.socket;
import core.thread : Thread;
import core.time : dur, Duration;
import std.array : empty;
import std.conv : text, to;
import std.exception : enforce;
import std.socket;
import thrift.base;
import thrift.transport.base;
import thrift.internal.socket;
/**
* Common parts of a socket TTransport implementation, regardless of how the
* actual I/O is performed (sync/async).
*/
abstract class TSocketBase : TBaseTransport {
/**
* Constructor that takes an already created, connected (!) socket.
*
* Params:
* socket = Already created, connected socket object.
*/
this(Socket socket) {
socket_ = socket;
setSocketOpts();
}
/**
* Creates a new unconnected socket that will connect to the given host
* on the given port.
*
* Params:
* host = Remote host.
* port = Remote port.
*/
this(string host, ushort port) {
host_ = host;
port_ = port;
}
/**
* Checks whether the socket is connected.
*/
override bool isOpen() @property {
return socket_ !is null;
}
/**
* Writes as much data to the socket as there can be in a single OS call.
*
* Params:
* buf = Data to write.
*
* Returns: The actual number of bytes written. Never more than buf.length.
*/
abstract size_t writeSome(in ubyte[] buf) out (written) {
// DMD @@BUG@@: Enabling this e.g. fails the contract in the
// async_test_server, because buf.length evaluates to 0 here, even though
// in the method body it correctly is 27 (equal to the return value).
version (none) assert(written <= buf.length, text("Implementation wrote " ~
"more data than requested to?! (", written, " vs. ", buf.length, ")"));
} body {
assert(0, "DMD bug? Why would contracts work for interfaces, but not "
"for abstract methods? "
"(Error: function […] in and out contracts require function body");
}
/**
* Returns the actual address of the peer the socket is connected to.
*
* In contrast, the host and port properties contain the address used to
* establish the connection, and are not updated after the connection.
*
* The socket must be open when calling this.
*/
Address getPeerAddress() {
enforce(isOpen, new TTransportException("Cannot get peer host for " ~
"closed socket.", TTransportException.Type.NOT_OPEN));
if (!peerAddress_) {
peerAddress_ = socket_.remoteAddress();
assert(peerAddress_);
}
return peerAddress_;
}
/**
* The host the socket is connected to or will connect to. Null if an
* already connected socket was used to construct the object.
*/
string host() const @property {
return host_;
}
/**
* The port the socket is connected to or will connect to. Zero if an
* already connected socket was used to construct the object.
*/
ushort port() const @property {
return port_;
}
/// The socket send timeout.
Duration sendTimeout() const @property {
return sendTimeout_;
}
/// Ditto
void sendTimeout(Duration value) @property {
sendTimeout_ = value;
}
/// The socket receiving timeout. Values smaller than 500 ms are not
/// supported on Windows.
Duration recvTimeout() const @property {
return recvTimeout_;
}
/// Ditto
void recvTimeout(Duration value) @property {
recvTimeout_ = value;
}
/**
* Returns the OS handle of the underlying socket.
*
* Should not usually be used directly, but access to it can be necessary
* to interface with C libraries.
*/
typeof(socket_.handle()) socketHandle() @property {
return socket_.handle();
}
protected:
/**
* Sets the needed socket options.
*/
void setSocketOpts() {
try {
alias SocketOptionLevel.SOCKET lvlSock;
Linger l;
l.on = 0;
l.time = 0;
socket_.setOption(lvlSock, SocketOption.LINGER, l);
} catch (SocketException e) {
logError("Could not set socket option: %s", e);
}
// Just try to disable Nagle's algorithm this will fail if we are passed
// in a non-TCP socket via the Socket-accepting constructor.
try {
socket_.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
} catch (SocketException e) {}
}
/// Remote host.
string host_;
/// Remote port.
ushort port_;
/// Timeout for sending.
Duration sendTimeout_;
/// Timeout for receiving.
Duration recvTimeout_;
/// Cached peer address.
Address peerAddress_;
/// Cached peer host name.
string peerHost_;
/// Cached peer port.
ushort peerPort_;
/// Wrapped socket object.
Socket socket_;
}
/**
* Socket implementation of the TTransport interface.
*
* Due to the limitations of std.socket, currently only TCP/IP sockets are
* supported (i.e. Unix domain sockets are not).
*/
class TSocket : TSocketBase {
///
this(Socket socket) {
super(socket);
}
///
this(string host, ushort port) {
super(host, port);
}
/**
* Connects the socket.
*/
override void open() {
if (isOpen) return;
enforce(!host_.empty, new TTransportException(
"Cannot open socket to null host.", TTransportException.Type.NOT_OPEN));
enforce(port_ != 0, new TTransportException(
"Cannot open socket to port zero.", TTransportException.Type.NOT_OPEN));
Address[] addrs;
try {
addrs = getAddress(host_, port_);
} catch (SocketException e) {
throw new TTransportException("Could not resolve given host string.",
TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
}
Exception[] errors;
foreach (addr; addrs) {
try {
socket_ = new TcpSocket(addr.addressFamily);
setSocketOpts();
socket_.connect(addr);
break;
} catch (SocketException e) {
errors ~= e;
}
}
if (errors.length == addrs.length) {
socket_ = null;
// Need to throw a TTransportException to abide the TTransport API.
import std.algorithm, std.range;
throw new TTransportException(
text("Failed to connect to ", host_, ":", port_, "."),
TTransportException.Type.NOT_OPEN,
__FILE__, __LINE__,
new TCompoundOperationException(
text(
"All addresses tried failed (",
joiner(map!q{text(a._0, `: "`, a._1.msg, `"`)}(zip(addrs, errors)), ", "),
")."
),
errors
)
);
}
}
/**
* Closes the socket.
*/
override void close() {
if (!isOpen) return;
socket_.close();
socket_ = null;
}
override bool peek() {
if (!isOpen) return false;
ubyte buf;
auto r = socket_.receive((&buf)[0 .. 1], SocketFlags.PEEK);
if (r == -1) {
auto lastErrno = getSocketErrno();
static if (connresetOnPeerShutdown) {
if (lastErrno == ECONNRESET) {
close();
return false;
}
}
throw new TTransportException("Peeking into socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
return (r > 0);
}
override size_t read(ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));
typeof(getSocketErrno()) lastErrno;
ushort tries;
while (tries++ <= maxRecvRetries_) {
auto r = socket_.receive(cast(void[])buf);
// If recv went fine, immediately return.
if (r >= 0) return r;
// Something went wrong, find out how to handle it.
lastErrno = getSocketErrno();
if (lastErrno == INTERRUPTED_ERRNO) {
// If the syscall was interrupted, just try again.
continue;
}
static if (connresetOnPeerShutdown) {
// See top comment.
if (lastErrno == ECONNRESET) {
return 0;
}
}
// Not an error which is handled in a special way, just leave the loop.
break;
}
if (isSocketCloseErrno(lastErrno)) {
close();
throw new TTransportException("Receiving failed, closing socket: " ~
socketErrnoString(lastErrno), TTransportException.Type.NOT_OPEN);
} else if (lastErrno == TIMEOUT_ERRNO) {
throw new TTransportException(TTransportException.Type.TIMED_OUT);
} else {
throw new TTransportException("Receiving from socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
}
override void write(in ubyte[] buf) {
size_t sent;
while (sent < buf.length) {
auto b = writeSome(buf[sent .. $]);
if (b == 0) {
// This should only happen if the timeout set with SO_SNDTIMEO expired.
throw new TTransportException("send() timeout expired.",
TTransportException.Type.TIMED_OUT);
}
sent += b;
}
assert(sent == buf.length);
}
override size_t writeSome(in ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot write if file is not open.", TTransportException.Type.NOT_OPEN));
auto r = socket_.send(buf);
// Everything went well, just return the number of bytes written.
if (r > 0) return r;
// Handle error conditions.
if (r < 0) {
auto lastErrno = getSocketErrno();
if (lastErrno == WOULD_BLOCK_ERRNO) {
// Not an exceptional error per se even with blocking sockets,
// EAGAIN apparently is returned sometimes on out-of-resource
// conditions (see the C++ implementation for details). Also, this
// allows using TSocket with non-blocking sockets e.g. in
// TNonblockingServer.
return 0;
}
auto type = TTransportException.Type.UNKNOWN;
if (isSocketCloseErrno(lastErrno)) {
type = TTransportException.Type.NOT_OPEN;
close();
}
throw new TTransportException("Sending to socket failed: " ~
socketErrnoString(lastErrno), type);
}
// send() should never return 0.
throw new TTransportException("Sending to socket failed (0 bytes written).",
TTransportException.Type.UNKNOWN);
}
override void sendTimeout(Duration value) @property {
super.sendTimeout(value);
setTimeout(SocketOption.SNDTIMEO, value);
}
override void recvTimeout(Duration value) @property {
super.recvTimeout(value);
setTimeout(SocketOption.RCVTIMEO, value);
}
/**
* Maximum number of retries for receiving from socket on read() in case of
* EAGAIN/EINTR.
*/
ushort maxRecvRetries() @property const {
return maxRecvRetries_;
}
/// Ditto
void maxRecvRetries(ushort value) @property {
maxRecvRetries_ = value;
}
/// Ditto
enum DEFAULT_MAX_RECV_RETRIES = 5;
protected:
override void setSocketOpts() {
super.setSocketOpts();
setTimeout(SocketOption.SNDTIMEO, sendTimeout_);
setTimeout(SocketOption.RCVTIMEO, recvTimeout_);
}
void setTimeout(SocketOption type, Duration value) {
assert(type == SocketOption.SNDTIMEO || type == SocketOption.RCVTIMEO);
version (Win32) {
if (value > dur!"hnsecs"(0) && value < dur!"msecs"(500)) {
logError(
"Socket %s timeout of %s ms might be raised to 500 ms on Windows.",
(type == SocketOption.SNDTIMEO) ? "send" : "receive",
value.total!"msecs"
);
}
}
if (socket_) {
try {
socket_.setOption(SocketOptionLevel.SOCKET, type, value);
} catch (SocketException e) {
throw new TTransportException(
"Could not set timeout.",
TTransportException.Type.UNKNOWN,
__FILE__,
__LINE__,
e
);
}
}
}
/// Maximum number of recv() retries.
ushort maxRecvRetries_ = DEFAULT_MAX_RECV_RETRIES;
}

View file

@ -0,0 +1,680 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* OpenSSL socket implementation, in large parts ported from C++.
*/
module thrift.transport.ssl;
import core.exception : onOutOfMemoryError;
import core.stdc.errno : errno, EINTR;
import core.sync.mutex : Mutex;
import core.memory : GC;
import core.stdc.config;
import core.stdc.stdlib : free, malloc;
import std.ascii : toUpper;
import std.array : empty, front, popFront;
import std.conv : emplace, to;
import std.exception : enforce;
import std.socket : Address, InternetAddress, Internet6Address, Socket;
import std.string : toStringz;
import deimos.openssl.err;
import deimos.openssl.rand;
import deimos.openssl.ssl;
import deimos.openssl.x509v3;
import thrift.base;
import thrift.internal.ssl;
import thrift.transport.base;
import thrift.transport.socket;
/**
* SSL encrypted socket implementation using OpenSSL.
*
* Note:
* On Posix systems which do not have the BSD-specific SO_NOSIGPIPE flag, you
* might want to ignore the SIGPIPE signal, as OpenSSL might try to write to
* a closed socket if the peer disconnects abruptly:
* ---
* import core.stdc.signal;
* import core.sys.posix.signal;
* signal(SIGPIPE, SIG_IGN);
* ---
*/
final class TSSLSocket : TSocket {
/**
* Creates an instance that wraps an already created, connected (!) socket.
*
* Params:
* context = The SSL socket context to use. A reference to it is stored so
* that it doesn't get cleaned up while the socket is used.
* socket = Already created, connected socket object.
*/
this(TSSLContext context, Socket socket) {
super(socket);
context_ = context;
serverSide_ = context.serverSide;
accessManager_ = context.accessManager;
}
/**
* Creates a new unconnected socket that will connect to the given host
* on the given port.
*
* Params:
* context = The SSL socket context to use. A reference to it is stored so
* that it doesn't get cleaned up while the socket is used.
* host = Remote host.
* port = Remote port.
*/
this(TSSLContext context, string host, ushort port) {
super(host, port);
context_ = context;
serverSide_ = context.serverSide;
accessManager_ = context.accessManager;
}
override bool isOpen() @property {
if (ssl_ is null || !super.isOpen()) return false;
auto shutdown = SSL_get_shutdown(ssl_);
bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN) != 0;
bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN) != 0;
return !(shutdownReceived && shutdownSent);
}
override bool peek() {
if (!isOpen) return false;
checkHandshake();
byte bt;
auto rc = SSL_peek(ssl_, &bt, bt.sizeof);
enforce(rc >= 0, getSSLException("SSL_peek"));
if (rc == 0) {
ERR_clear_error();
}
return (rc > 0);
}
override void open() {
enforce(!serverSide_, "Cannot open a server-side SSL socket.");
if (isOpen) return;
super.open();
}
override void close() {
if (!isOpen) return;
if (ssl_ !is null) {
// Two-step SSL shutdown.
auto rc = SSL_shutdown(ssl_);
if (rc == 0) {
rc = SSL_shutdown(ssl_);
}
if (rc < 0) {
// Do not throw an exception here as leaving the transport "open" will
// probably produce only more errors, and the chance we can do
// something about the error e.g. by retrying is very low.
logError("Error shutting down SSL: %s", getSSLException());
}
SSL_free(ssl_);
ssl_ = null;
ERR_remove_state(0);
}
super.close();
}
override size_t read(ubyte[] buf) {
checkHandshake();
int bytes;
foreach (_; 0 .. maxRecvRetries) {
bytes = SSL_read(ssl_, buf.ptr, cast(int)buf.length);
if (bytes >= 0) break;
auto errnoCopy = errno;
if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
if (ERR_get_error() == 0 && errnoCopy == EINTR) {
// FIXME: Windows.
continue;
}
}
throw getSSLException("SSL_read");
}
return bytes;
}
override void write(in ubyte[] buf) {
checkHandshake();
// Loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
size_t written = 0;
while (written < buf.length) {
auto bytes = SSL_write(ssl_, buf.ptr + written,
cast(int)(buf.length - written));
if (bytes <= 0) {
throw getSSLException("SSL_write");
}
written += bytes;
}
}
override void flush() {
checkHandshake();
auto bio = SSL_get_wbio(ssl_);
enforce(bio !is null, new TSSLException("SSL_get_wbio returned null"));
auto rc = BIO_flush(bio);
enforce(rc == 1, getSSLException("BIO_flush"));
}
/**
* Whether to use client or server side SSL handshake protocol.
*/
bool serverSide() @property const {
return serverSide_;
}
/// Ditto
void serverSide(bool value) @property {
serverSide_ = value;
}
/**
* The access manager to use.
*/
void accessManager(TAccessManager value) @property {
accessManager_ = value;
}
private:
void checkHandshake() {
enforce(super.isOpen(), new TTransportException(
TTransportException.Type.NOT_OPEN));
if (ssl_ !is null) return;
ssl_ = context_.createSSL();
SSL_set_fd(ssl_, socketHandle);
int rc;
if (serverSide_) {
rc = SSL_accept(ssl_);
} else {
rc = SSL_connect(ssl_);
}
enforce(rc > 0, getSSLException());
authorize(ssl_, accessManager_, getPeerAddress(),
(serverSide_ ? getPeerAddress().toHostNameString() : host));
}
bool serverSide_;
SSL* ssl_;
TSSLContext context_;
TAccessManager accessManager_;
}
/**
* Represents an OpenSSL context with certification settings, etc. and handles
* initialization/teardown.
*
* OpenSSL is initialized when the first instance of this class is created
* and shut down when the last one is destroyed (thread-safe).
*/
class TSSLContext {
this() {
initMutex_.lock();
scope(exit) initMutex_.unlock();
if (count_ == 0) {
initializeOpenSSL();
randomize();
}
count_++;
ctx_ = SSL_CTX_new(TLSv1_method());
enforce(ctx_, getSSLException("SSL_CTX_new"));
SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
}
~this() {
initMutex_.lock();
scope(exit) initMutex_.unlock();
if (ctx_ !is null) {
SSL_CTX_free(ctx_);
ctx_ = null;
}
count_--;
if (count_ == 0) {
cleanupOpenSSL();
}
}
/**
* Ciphers to be used in SSL handshake process.
*
* The string must be in the colon-delimited OpenSSL notation described in
* ciphers(1), for example: "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH".
*/
void ciphers(string enable) @property {
auto rc = SSL_CTX_set_cipher_list(ctx_, toStringz(enable));
enforce(ERR_peek_error() == 0, getSSLException("SSL_CTX_set_cipher_list"));
enforce(rc > 0, new TSSLException("None of specified ciphers are supported"));
}
/**
* Whether peer is required to present a valid certificate.
*/
void authenticate(bool required) @property {
int mode;
if (required) {
mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
SSL_VERIFY_CLIENT_ONCE;
} else {
mode = SSL_VERIFY_NONE;
}
SSL_CTX_set_verify(ctx_, mode, null);
}
/**
* Load server certificate.
*
* Params:
* path = Path to the certificate file.
* format = Certificate file format. Defaults to PEM, which is currently
* the only one supported.
*/
void loadCertificate(string path, string format = "PEM") {
enforce(path !is null && format !is null, new TTransportException(
"loadCertificateChain: either <path> or <format> is null",
TTransportException.Type.BAD_ARGS));
if (format == "PEM") {
enforce(SSL_CTX_use_certificate_chain_file(ctx_, toStringz(path)),
getSSLException(
`Could not load SSL server certificate from file "` ~ path ~ `"`
)
);
} else {
throw new TSSLException("Unsupported certificate format: " ~ format);
}
}
/*
* Load private key.
*
* Params:
* path = Path to the certificate file.
* format = Private key file format. Defaults to PEM, which is currently
* the only one supported.
*/
void loadPrivateKey(string path, string format = "PEM") {
enforce(path !is null && format !is null, new TTransportException(
"loadPrivateKey: either <path> or <format> is NULL",
TTransportException.Type.BAD_ARGS));
if (format == "PEM") {
enforce(SSL_CTX_use_PrivateKey_file(ctx_, toStringz(path), SSL_FILETYPE_PEM),
getSSLException(
`Could not load SSL private key from file "` ~ path ~ `"`
)
);
} else {
throw new TSSLException("Unsupported certificate format: " ~ format);
}
}
/**
* Load trusted certificates from specified file (in PEM format).
*
* Params.
* path = Path to the file containing the trusted certificates.
*/
void loadTrustedCertificates(string path) {
enforce(path !is null, new TTransportException(
"loadTrustedCertificates: <path> is NULL",
TTransportException.Type.BAD_ARGS));
enforce(SSL_CTX_load_verify_locations(ctx_, toStringz(path), null),
getSSLException(
`Could not load SSL trusted certificate list from file "` ~ path ~ `"`
)
);
}
/**
* Called during OpenSSL initialization to seed the OpenSSL entropy pool.
*
* Defaults to simply calling RAND_poll(), but it can be overwritten if a
* different, perhaps more secure implementation is desired.
*/
void randomize() {
RAND_poll();
}
/**
* Whether to use client or server side SSL handshake protocol.
*/
bool serverSide() @property const {
return serverSide_;
}
/// Ditto
void serverSide(bool value) @property {
serverSide_ = value;
}
/**
* The access manager to use.
*/
TAccessManager accessManager() @property {
if (!serverSide_ && !accessManager_) {
accessManager_ = new TDefaultClientAccessManager;
}
return accessManager_;
}
/// Ditto
void accessManager(TAccessManager value) @property {
accessManager_ = value;
}
SSL* createSSL() out (result) {
assert(result);
} body {
auto result = SSL_new(ctx_);
enforce(result, getSSLException("SSL_new"));
return result;
}
protected:
/**
* Override this method for custom password callback. It may be called
* multiple times at any time during a session as necessary.
*
* Params:
* size = Maximum length of password, including null byte.
*/
string getPassword(int size) nothrow out(result) {
assert(result.length < size);
} body {
return "";
}
/**
* Notifies OpenSSL to use getPassword() instead of the default password
* callback with getPassword().
*/
void overrideDefaultPasswordCallback() {
SSL_CTX_set_default_passwd_cb(ctx_, &passwordCallback);
SSL_CTX_set_default_passwd_cb_userdata(ctx_, cast(void*)this);
}
SSL_CTX* ctx_;
private:
bool serverSide_;
TAccessManager accessManager_;
shared static this() {
initMutex_ = new Mutex();
}
static void initializeOpenSSL() {
if (initialized_) {
return;
}
initialized_ = true;
SSL_library_init();
SSL_load_error_strings();
mutexes_ = new Mutex[CRYPTO_num_locks()];
foreach (ref m; mutexes_) {
m = new Mutex;
}
import thrift.internal.traits;
// As per the OpenSSL threads manpage, this isn't needed on Windows.
version (Posix) {
CRYPTO_set_id_callback(assumeNothrow(&threadIdCallback));
}
CRYPTO_set_locking_callback(assumeNothrow(&lockingCallback));
CRYPTO_set_dynlock_create_callback(assumeNothrow(&dynlockCreateCallback));
CRYPTO_set_dynlock_lock_callback(assumeNothrow(&dynlockLockCallback));
CRYPTO_set_dynlock_destroy_callback(assumeNothrow(&dynlockDestroyCallback));
}
static void cleanupOpenSSL() {
if (!initialized_) return;
initialized_ = false;
CRYPTO_set_locking_callback(null);
CRYPTO_set_dynlock_create_callback(null);
CRYPTO_set_dynlock_lock_callback(null);
CRYPTO_set_dynlock_destroy_callback(null);
CRYPTO_cleanup_all_ex_data();
ERR_free_strings();
ERR_remove_state(0);
}
static extern(C) {
version (Posix) {
import core.sys.posix.pthread : pthread_self;
c_ulong threadIdCallback() {
return cast(c_ulong)pthread_self();
}
}
void lockingCallback(int mode, int n, const(char)* file, int line) {
if (mode & CRYPTO_LOCK) {
mutexes_[n].lock();
} else {
mutexes_[n].unlock();
}
}
CRYPTO_dynlock_value* dynlockCreateCallback(const(char)* file, int line) {
enum size = __traits(classInstanceSize, Mutex);
auto mem = malloc(size)[0 .. size];
if (!mem) onOutOfMemoryError();
GC.addRange(mem.ptr, size);
auto mutex = emplace!Mutex(mem);
return cast(CRYPTO_dynlock_value*)mutex;
}
void dynlockLockCallback(int mode, CRYPTO_dynlock_value* l,
const(char)* file, int line)
{
if (l is null) return;
if (mode & CRYPTO_LOCK) {
(cast(Mutex)l).lock();
} else {
(cast(Mutex)l).unlock();
}
}
void dynlockDestroyCallback(CRYPTO_dynlock_value* l,
const(char)* file, int line)
{
GC.removeRange(l);
destroy(cast(Mutex)l);
free(l);
}
int passwordCallback(char* password, int size, int, void* data) nothrow {
auto context = cast(TSSLContext) data;
auto userPassword = context.getPassword(size);
auto len = userPassword.length;
if (len > size) {
len = size;
}
password[0 .. len] = userPassword[0 .. len]; // TODO: \0 handling correct?
return cast(int)len;
}
}
static __gshared bool initialized_;
static __gshared Mutex initMutex_;
static __gshared Mutex[] mutexes_;
static __gshared uint count_;
}
/**
* Decides whether a remote host is legitimate or not.
*
* It is usually set at a TSSLContext, which then passes it to all the created
* TSSLSockets.
*/
class TAccessManager {
///
enum Decision {
DENY = -1, /// Deny access.
SKIP = 0, /// Cannot decide, move on to next check (deny if last).
ALLOW = 1 /// Allow access.
}
/**
* Determines whether a peer should be granted access or not based on its
* IP address.
*
* Called once after SSL handshake is completes successfully and before peer
* certificate is examined.
*
* If a valid decision (ALLOW or DENY) is returned, the peer certificate
* will not be verified.
*/
Decision verify(Address address) {
return Decision.DENY;
}
/**
* Determines whether a peer should be granted access or not based on a
* name from its certificate.
*
* Called every time a DNS subjectAltName/common name is extracted from the
* peer's certificate.
*
* Params:
* host = The actual host name string from the socket connection.
* certHost = A host name string from the certificate.
*/
Decision verify(string host, const(char)[] certHost) {
return Decision.DENY;
}
/**
* Determines whether a peer should be granted access or not based on an IP
* address from its certificate.
*
* Called every time an IP subjectAltName is extracted from the peer's
* certificate.
*
* Params:
* address = The actual address from the socket connection.
* certHost = A host name string from the certificate.
*/
Decision verify(Address address, ubyte[] certAddress) {
return Decision.DENY;
}
}
/**
* Default access manager implementation, which just checks the host name
* resp. IP address of the connection against the certificate.
*/
class TDefaultClientAccessManager : TAccessManager {
override Decision verify(Address address) {
return Decision.SKIP;
}
override Decision verify(string host, const(char)[] certHost) {
if (host.empty || certHost.empty) {
return Decision.SKIP;
}
return (matchName(host, certHost) ? Decision.ALLOW : Decision.SKIP);
}
override Decision verify(Address address, ubyte[] certAddress) {
bool match;
if (certAddress.length == 4) {
if (auto ia = cast(InternetAddress)address) {
match = ((cast(ubyte*)ia.addr())[0 .. 4] == certAddress[]);
}
} else if (certAddress.length == 16) {
if (auto ia = cast(Internet6Address)address) {
match = (ia.addr() == certAddress[]);
}
}
return (match ? Decision.ALLOW : Decision.SKIP);
}
}
private {
/**
* Matches a name with a pattern. The pattern may include wildcard. A single
* wildcard "*" can match up to one component in the domain name.
*
* Params:
* host = Host name to match, typically the SSL remote peer.
* pattern = Host name pattern, typically from the SSL certificate.
*
* Returns: true if host matches pattern, false otherwise.
*/
bool matchName(const(char)[] host, const(char)[] pattern) {
while (!host.empty && !pattern.empty) {
if (toUpper(pattern.front) == toUpper(host.front)) {
host.popFront;
pattern.popFront;
} else if (pattern.front == '*') {
while (!host.empty && host.front != '.') {
host.popFront;
}
pattern.popFront;
} else {
break;
}
}
return (host.empty && pattern.empty);
}
unittest {
enforce(matchName("thrift.apache.org", "*.apache.org"));
enforce(!matchName("thrift.apache.org", "apache.org"));
enforce(matchName("thrift.apache.org", "thrift.*.*"));
enforce(matchName("", ""));
enforce(!matchName("", "*"));
}
}
/**
* SSL-level exception.
*/
class TSSLException : TTransportException {
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, TTransportException.Type.INTERNAL_ERROR, file, line, next);
}
}

View file

@ -0,0 +1,497 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.zlib;
import core.bitop : bswap;
import etc.c.zlib;
import std.algorithm : min;
import std.array : empty;
import std.conv : to;
import std.exception : enforce;
import thrift.base;
import thrift.transport.base;
/**
* zlib transport. Compresses (deflates) data before writing it to the
* underlying transport, and decompresses (inflates) it after reading.
*/
final class TZlibTransport : TBaseTransport {
// These defaults have yet to be optimized.
enum DEFAULT_URBUF_SIZE = 128;
enum DEFAULT_CRBUF_SIZE = 1024;
enum DEFAULT_UWBUF_SIZE = 128;
enum DEFAULT_CWBUF_SIZE = 1024;
/**
* Constructs a new zlib transport.
*
* Params:
* transport = The underlying transport to wrap.
* urbufSize = The size of the uncompressed reading buffer, in bytes.
* crbufSize = The size of the compressed reading buffer, in bytes.
* uwbufSize = The size of the uncompressed writing buffer, in bytes.
* cwbufSize = The size of the compressed writing buffer, in bytes.
*/
this(
TTransport transport,
size_t urbufSize = DEFAULT_URBUF_SIZE,
size_t crbufSize = DEFAULT_CRBUF_SIZE,
size_t uwbufSize = DEFAULT_UWBUF_SIZE,
size_t cwbufSize = DEFAULT_CWBUF_SIZE
) {
transport_ = transport;
enforce(uwbufSize >= MIN_DIRECT_DEFLATE_SIZE, new TTransportException(
"TZLibTransport: uncompressed write buffer must be at least " ~
to!string(MIN_DIRECT_DEFLATE_SIZE) ~ "bytes in size.",
TTransportException.Type.BAD_ARGS));
urbuf_ = new ubyte[urbufSize];
crbuf_ = new ubyte[crbufSize];
uwbuf_ = new ubyte[uwbufSize];
cwbuf_ = new ubyte[cwbufSize];
rstream_ = new z_stream;
rstream_.next_in = crbuf_.ptr;
rstream_.avail_in = 0;
rstream_.next_out = urbuf_.ptr;
rstream_.avail_out = to!uint(urbuf_.length);
wstream_ = new z_stream;
wstream_.next_in = uwbuf_.ptr;
wstream_.avail_in = 0;
wstream_.next_out = cwbuf_.ptr;
wstream_.avail_out = to!uint(crbuf_.length);
zlibEnforce(inflateInit(rstream_), rstream_);
scope (failure) {
zlibLogError(inflateEnd(rstream_), rstream_);
}
zlibEnforce(deflateInit(wstream_, Z_DEFAULT_COMPRESSION), wstream_);
}
~this() {
zlibLogError(inflateEnd(rstream_), rstream_);
auto result = deflateEnd(wstream_);
// Z_DATA_ERROR may indicate unflushed data, so just ignore it.
if (result != Z_DATA_ERROR) {
zlibLogError(result, wstream_);
}
}
/**
* Returns the wrapped transport.
*/
TTransport underlyingTransport() @property {
return transport_;
}
override bool isOpen() @property {
return readAvail > 0 || transport_.isOpen;
}
override bool peek() {
return readAvail > 0 || transport_.peek();
}
override void open() {
transport_.open();
}
override void close() {
transport_.close();
}
override size_t read(ubyte[] buf) {
// The C++ implementation suggests to skip urbuf on big reads in future
// versions, we would benefit from it as well.
auto origLen = buf.length;
while (true) {
auto give = min(readAvail, buf.length);
// If std.range.put was optimized for slicable ranges, it could be used
// here as well.
buf[0 .. give] = urbuf_[urpos_ .. urpos_ + give];
buf = buf[give .. $];
urpos_ += give;
auto need = buf.length;
if (need == 0) {
// We could manage to get the all the data requested.
return origLen;
}
if (inputEnded_ || (need < origLen && rstream_.avail_in == 0)) {
// We didn't fill buf completely, but there is no more data available.
return origLen - need;
}
// Refill our buffer by reading more data through zlib.
rstream_.next_out = urbuf_.ptr;
rstream_.avail_out = to!uint(urbuf_.length);
urpos_ = 0;
if (!readFromZlib()) {
// Couldn't get more data from the underlying transport.
return origLen - need;
}
}
}
override void write(in ubyte[] buf) {
enforce(!outputFinished_, new TTransportException(
"write() called after finish()", TTransportException.Type.BAD_ARGS));
auto len = buf.length;
if (len > MIN_DIRECT_DEFLATE_SIZE) {
flushToZlib(uwbuf_[0 .. uwpos_], Z_NO_FLUSH);
uwpos_ = 0;
flushToZlib(buf, Z_NO_FLUSH);
} else if (len > 0) {
if (uwbuf_.length - uwpos_ < len) {
flushToZlib(uwbuf_[0 .. uwpos_], Z_NO_FLUSH);
uwpos_ = 0;
}
uwbuf_[uwpos_ .. uwpos_ + len] = buf[];
uwpos_ += len;
}
}
override void flush() {
enforce(!outputFinished_, new TTransportException(
"flush() called after finish()", TTransportException.Type.BAD_ARGS));
flushToTransport(Z_SYNC_FLUSH);
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= readAvail) {
return urbuf_[urpos_ .. $];
}
return null;
}
override void consume(size_t len) {
enforce(readAvail >= len, new TTransportException(
"consume() did not follow a borrow().", TTransportException.Type.BAD_ARGS));
urpos_ += len;
}
/**
* Finalize the zlib stream.
*
* This causes zlib to flush any pending write data and write end-of-stream
* information, including the checksum. Once finish() has been called, no
* new data can be written to the stream.
*/
void finish() {
enforce(!outputFinished_, new TTransportException(
"flush() called on already finished TZlibTransport",
TTransportException.Type.BAD_ARGS));
flushToTransport(Z_FINISH);
}
/**
* Verify the checksum at the end of the zlib stream (by finish()).
*
* May only be called after all data has been read.
*
* Throws: TTransportException when the checksum is corrupted or there is
* still unread data left.
*/
void verifyChecksum() {
// If zlib has already reported the end of the stream, the checksum has
// been verified, no.
if (inputEnded_) return;
enforce(!readAvail, new TTransportException(
"verifyChecksum() called before end of zlib stream",
TTransportException.Type.CORRUPTED_DATA));
rstream_.next_out = urbuf_.ptr;
rstream_.avail_out = to!uint(urbuf_.length);
urpos_ = 0;
// readFromZlib() will throw an exception if the checksum is bad.
enforce(readFromZlib(), new TTransportException(
"checksum not available yet in verifyChecksum()",
TTransportException.Type.CORRUPTED_DATA));
enforce(inputEnded_, new TTransportException(
"verifyChecksum() called before end of zlib stream",
TTransportException.Type.CORRUPTED_DATA));
// If we get here, we are at the end of the stream and thus zlib has
// successfully verified the checksum.
}
private:
size_t readAvail() const @property {
return urbuf_.length - rstream_.avail_out - urpos_;
}
bool readFromZlib() {
assert(!inputEnded_);
if (rstream_.avail_in == 0) {
// zlib has used up all the compressed data we provided in crbuf, read
// some more from the underlying transport.
auto got = transport_.read(crbuf_);
if (got == 0) return false;
rstream_.next_in = crbuf_.ptr;
rstream_.avail_in = to!uint(got);
}
// We have some compressed data now, uncompress it.
auto zlib_result = inflate(rstream_, Z_SYNC_FLUSH);
if (zlib_result == Z_STREAM_END) {
inputEnded_ = true;
} else {
zlibEnforce(zlib_result, rstream_);
}
return true;
}
void flushToTransport(int type) {
// Compress remaining data in uwbuf_ to cwbuf_.
flushToZlib(uwbuf_[0 .. uwpos_], type);
uwpos_ = 0;
// Write all compressed data to the transport.
transport_.write(cwbuf_[0 .. $ - wstream_.avail_out]);
wstream_.next_out = cwbuf_.ptr;
wstream_.avail_out = to!uint(cwbuf_.length);
// Flush the transport.
transport_.flush();
}
void flushToZlib(in ubyte[] buf, int type) {
wstream_.next_in = cast(ubyte*)buf.ptr; // zlib only reads, cast is safe.
wstream_.avail_in = to!uint(buf.length);
while (true) {
if (type == Z_NO_FLUSH && wstream_.avail_in == 0) {
break;
}
if (wstream_.avail_out == 0) {
// cwbuf has been exhausted by zlib, flush to the underlying transport.
transport_.write(cwbuf_);
wstream_.next_out = cwbuf_.ptr;
wstream_.avail_out = to!uint(cwbuf_.length);
}
auto zlib_result = deflate(wstream_, type);
if (type == Z_FINISH && zlib_result == Z_STREAM_END) {
assert(wstream_.avail_in == 0);
outputFinished_ = true;
break;
}
zlibEnforce(zlib_result, wstream_);
if ((type == Z_SYNC_FLUSH || type == Z_FULL_FLUSH) &&
wstream_.avail_in == 0 && wstream_.avail_out != 0) {
break;
}
}
}
static void zlibEnforce(int status, z_stream* stream) {
if (status != Z_OK) {
throw new TZlibException(status, stream.msg);
}
}
static void zlibLogError(int status, z_stream* stream) {
if (status != Z_OK) {
logError("TZlibTransport: zlib failure in destructor: %s",
TZlibException.errorMessage(status, stream.msg));
}
}
// Writes smaller than this are buffered up (due to zlib handling overhead).
// Larger (or equal) writes are dumped straight to zlib.
enum MIN_DIRECT_DEFLATE_SIZE = 32;
TTransport transport_;
z_stream* rstream_;
z_stream* wstream_;
/// Whether zlib has reached the end of the input stream.
bool inputEnded_;
/// Whether the output stream was already finish()ed.
bool outputFinished_;
/// Compressed input data buffer.
ubyte[] crbuf_;
/// Uncompressed input data buffer.
ubyte[] urbuf_;
size_t urpos_;
/// Uncompressed output data buffer (where small writes are accumulated
/// before handing over to zlib).
ubyte[] uwbuf_;
size_t uwpos_;
/// Compressed output data buffer (filled by zlib, we flush it to the
/// underlying transport).
ubyte[] cwbuf_;
}
/**
* Wraps given transports into TZlibTransports.
*/
alias TWrapperTransportFactory!TZlibTransport TZlibTransportFactory;
/**
* An INTERNAL_ERROR-type TTransportException originating from an error
* signaled by zlib.
*/
class TZlibException : TTransportException {
this(int statusCode, const(char)* msg) {
super(errorMessage(statusCode, msg), TTransportException.Type.INTERNAL_ERROR);
zlibStatusCode = statusCode;
zlibMsg = msg ? to!string(msg) : "(null)";
}
int zlibStatusCode;
string zlibMsg;
static string errorMessage(int statusCode, const(char)* msg) {
string result = "zlib error: ";
if (msg) {
result ~= to!string(msg);
} else {
result ~= "(no message)";
}
result ~= " (status code = " ~ to!string(statusCode) ~ ")";
return result;
}
}
version (unittest) {
import std.exception : collectException;
import thrift.transport.memory;
}
// Make sure basic reading/writing works.
unittest {
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
immutable ubyte[] data = [1, 2, 3, 4, 5];
zlib.write(data);
zlib.finish();
auto result = new ubyte[data.length];
zlib.readAll(result);
enforce(data == result);
zlib.verifyChecksum();
}
// Make sure there is no data is written if write() is never called.
unittest {
auto buf = new TMemoryBuffer;
{
scope zlib = new TZlibTransport(buf);
}
enforce(buf.getContents().length == 0);
}
// Make sure calling write()/flush()/finish() again after finish() throws.
unittest {
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
zlib.write([1, 2, 3, 4, 5]);
zlib.finish();
auto ex = collectException!TTransportException(zlib.write([6]));
enforce(ex && ex.type == TTransportException.Type.BAD_ARGS);
ex = collectException!TTransportException(zlib.flush());
enforce(ex && ex.type == TTransportException.Type.BAD_ARGS);
ex = collectException!TTransportException(zlib.finish());
enforce(ex && ex.type == TTransportException.Type.BAD_ARGS);
}
// Make sure verifying the checksum works even if it requires starting a new
// reading buffer after reading the payload has already been completed.
unittest {
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
immutable ubyte[] data = [1, 2, 3, 4, 5];
zlib.write(data);
zlib.finish();
zlib = new TZlibTransport(buf, TZlibTransport.DEFAULT_URBUF_SIZE,
buf.getContents().length - 1); // The last byte belongs to the checksum.
auto result = new ubyte[data.length];
zlib.readAll(result);
enforce(data == result);
zlib.verifyChecksum();
}
// Make sure verifyChecksum() throws if we messed with the checksum.
unittest {
import std.stdio;
import thrift.transport.range;
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
immutable ubyte[] data = [1, 2, 3, 4, 5];
zlib.write(data);
zlib.finish();
void testCorrupted(const(ubyte)[] corruptedData) {
auto reader = new TZlibTransport(tInputRangeTransport(corruptedData));
auto result = new ubyte[data.length];
try {
reader.readAll(result);
// If it does read without complaining, the result should be correct.
enforce(result == data);
} catch (TZlibException e) {}
auto ex = collectException!TTransportException(reader.verifyChecksum());
enforce(ex && ex.type == TTransportException.Type.CORRUPTED_DATA);
}
testCorrupted(buf.getContents()[0 .. $ - 1]);
auto modified = buf.getContents().dup;
++modified[$ - 1];
testCorrupted(modified);
}

View file

@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.awaitable;
import core.sync.condition;
import core.sync.mutex;
import core.time : Duration;
import std.exception : enforce;
import std.socket/+ : Socket, socketPair+/; // DMD @@BUG314@@
import thrift.base;
// To avoid DMD @@BUG6395@@.
import thrift.internal.algorithm;
/**
* An event that can occur at some point in the future and which can be
* awaited, either by blocking until it occurs, or by registering a callback
* delegate.
*/
interface TAwaitable {
/**
* Waits until the event occurs.
*
* Calling wait() for an event that has already occurred is a no-op.
*/
void wait();
/**
* Waits until the event occurs or the specified timeout expires.
*
* Calling wait() for an event that has already occurred is a no-op.
*
* Returns: Whether the event was triggered before the timeout expired.
*/
bool wait(Duration timeout);
/**
* Registers a callback that is called if the event occurs.
*
* The delegate will likely be invoked from a different thread, and is
* expected not to perform expensive work as it will usually be invoked
* synchronously by the notifying thread. The order in which registered
* callbacks are invoked is not specified.
*
* The callback must never throw, but nothrow semantics are difficult to
* enforce, so currently exceptions are just swallowed by
* TAwaitable implementations.
*
* If the event has already occurred, the delegate is immediately executed
* in the current thread.
*/
void addCallback(void delegate() dg);
/**
* Removes a previously added callback.
*
* Returns: Whether the callback could be found in the list, i.e. whether it
* was previously added.
*/
bool removeCallback(void delegate() dg);
}
/**
* A simple TAwaitable event triggered by just calling a trigger() method.
*/
class TOneshotEvent : TAwaitable {
this() {
mutex_ = new Mutex;
condition_ = new Condition(mutex_);
}
override void wait() {
synchronized (mutex_) {
while (!triggered_) condition_.wait();
}
}
override bool wait(Duration timeout) {
synchronized (mutex_) {
if (triggered_) return true;
condition_.wait(timeout);
return triggered_;
}
}
override void addCallback(void delegate() dg) {
mutex_.lock();
scope (failure) mutex_.unlock();
callbacks_ ~= dg;
if (triggered_) {
mutex_.unlock();
dg();
return;
}
mutex_.unlock();
}
override bool removeCallback(void delegate() dg) {
synchronized (mutex_) {
auto oldLength = callbacks_.length;
callbacks_ = removeEqual(callbacks_, dg);
return callbacks_.length < oldLength;
}
}
/**
* Triggers the event.
*
* Any registered event callbacks are executed synchronously before the
* function returns.
*/
void trigger() {
synchronized (mutex_) {
if (!triggered_) {
triggered_ = true;
condition_.notifyAll();
foreach (c; callbacks_) c();
}
}
}
private:
bool triggered_;
Mutex mutex_;
Condition condition_;
void delegate()[] callbacks_;
}
/**
* Translates TAwaitable events into dummy messages on a socket that can be
* used e.g. to wake up from a select() call.
*/
final class TSocketNotifier {
this() {
auto socks = socketPair();
foreach (s; socks) s.blocking = false;
sendSocket_ = socks[0];
recvSocket_ = socks[1];
}
/**
* The socket the messages will be sent to.
*/
Socket socket() @property {
return recvSocket_;
}
/**
* Atatches the socket notifier to the specified awaitable, causing it to
* write a byte to the notification socket when the awaitable callbacks are
* invoked.
*
* If the event has already been triggered, the dummy byte is written
* immediately to the socket.
*
* A socket notifier can only be attached to a single awaitable at a time.
*
* Throws: TException if the socket notifier is already attached.
*/
void attach(TAwaitable awaitable) {
enforce(!awaitable_, new TException("Already attached."));
awaitable.addCallback(&notify);
awaitable_ = awaitable;
}
/**
* Detaches the socket notifier from the awaitable it is currently attached
* to.
*
* Throws: TException if the socket notifier is not currently attached.
*/
void detach() {
enforce(awaitable_, new TException("Not attached."));
// Soak up any not currently read notification bytes.
ubyte[1] dummy = void;
while (recvSocket_.receive(dummy) != Socket.ERROR) {}
auto couldRemove = awaitable_.removeCallback(&notify);
assert(couldRemove);
awaitable_ = null;
}
private:
void notify() {
ubyte[1] zero;
sendSocket_.send(zero);
}
TAwaitable awaitable_;
Socket sendSocket_;
Socket recvSocket_;
}

View file

@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.cancellation;
import core.atomic;
import thrift.base;
import thrift.util.awaitable;
/**
* A cancellation request for asynchronous or blocking synchronous operations.
*
* It is passed to the entity creating an operation, which will usually monitor
* it either by polling or by adding event handlers, and cancel the operation
* if it is triggered.
*
* For synchronous operations, this usually means either throwing a
* TCancelledException or immediately returning, depending on whether
* cancellation is an expected part of the task outcome or not. For
* asynchronous operations, cancellation typically entails stopping background
* work and cancelling a result future, if not already completed.
*
* An operation accepting a TCancellation does not need to guarantee that it
* will actually be able to react to the cancellation request.
*/
interface TCancellation {
/**
* Whether the cancellation request has been triggered.
*/
bool triggered() const @property;
/**
* Throws a TCancelledException if the cancellation request has already been
* triggered.
*/
void throwIfTriggered() const;
/**
* A TAwaitable that can be used to wait for cancellation triggering.
*/
TAwaitable triggering() @property;
}
/**
* The origin of a cancellation request, which provides a way to actually
* trigger it.
*
* This design allows operations to pass the TCancellation on to sub-tasks,
* while making sure that the cancellation can only be triggered by the
* »outermost« instance waiting for the result.
*/
final class TCancellationOrigin : TCancellation {
this() {
event_ = new TOneshotEvent;
}
/**
* Triggers the cancellation request.
*/
void trigger() {
atomicStore(triggered_, true);
event_.trigger();
}
/+override+/ bool triggered() const @property {
return atomicLoad(triggered_);
}
/+override+/ void throwIfTriggered() const {
if (triggered) throw new TCancelledException;
}
/+override+/ TAwaitable triggering() @property {
return event_;
}
private:
shared bool triggered_;
TOneshotEvent event_;
}
///
class TCancelledException : TException {
///
this(string msg = null, string file = __FILE__, size_t line = __LINE__,
Throwable next = null
) {
super(msg ? msg : "The operation has been cancelled.", file, line, next);
}
}

View file

@ -0,0 +1,549 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.future;
import core.atomic;
import core.sync.condition;
import core.sync.mutex;
import core.time : Duration;
import std.array : empty, front, popFront;
import std.conv : to;
import std.exception : enforce;
import std.traits : BaseTypeTuple, isSomeFunction, ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.util.awaitable;
import thrift.util.cancellation;
/**
* Represents an operation which is executed asynchronously and the result of
* which will become available at some point in the future.
*
* Once a operation is completed, the result of the operation can be fetched
* via the get() family of methods. There are three possible cases: Either the
* operation succeeded, then its return value is returned, or it failed by
* throwing, in which case the exception is rethrown, or it was cancelled
* before, then a TCancelledException is thrown. There might be TFuture
* implementations which never possibly enter the cancelled state.
*
* All methods are thread-safe, but keep in mind that any exception object or
* result (if it is a reference type, of course) is shared between all
* get()-family invocations.
*/
interface TFuture(ResultType) {
/**
* The status the operation is currently in.
*
* An operation starts out in RUNNING status, and changes state to one of the
* others at most once afterwards.
*/
TFutureStatus status() @property;
/**
* A TAwaitable triggered when the operation leaves the RUNNING status.
*/
TAwaitable completion() @property;
/**
* Convenience shorthand for waiting until the result is available and then
* get()ing it.
*
* If the operation has already completed, the result is immediately
* returned.
*
* The result of this method is »alias this«'d to the interface, so that
* TFuture can be used as a drop-in replacement for a simple value in
* synchronous code.
*/
final ResultType waitGet() {
completion.wait();
return get();
}
final @property auto waitGetProperty() { return waitGet(); }
alias waitGetProperty this;
/**
* Convenience shorthand for waiting until the result is available and then
* get()ing it.
*
* If the operation completes in time, returns its result (resp. throws an
* exception for the failed/cancelled cases). If not, throws a
* TFutureException.
*/
final ResultType waitGet(Duration timeout) {
enforce(completion.wait(timeout), new TFutureException(
"Operation did not complete in time."));
return get();
}
/**
* Returns the result of the operation.
*
* Throws: TFutureException if the operation has been cancelled,
* TCancelledException if it is not yet done; the set exception if it
* failed.
*/
ResultType get();
/**
* Returns the captured exception if the operation failed, or null otherwise.
*
* Throws: TFutureException if not yet done, TCancelledException if the
* operation has been cancelled.
*/
Exception getException();
}
/**
* The states the operation offering a future interface can be in.
*/
enum TFutureStatus : byte {
RUNNING, /// The operation is still running.
SUCCEEDED, /// The operation completed without throwing an exception.
FAILED, /// The operation completed by throwing an exception.
CANCELLED /// The operation was cancelled.
}
/**
* A TFuture covering the simple but common case where the result is simply
* set by a call to succeed()/fail().
*
* All methods are thread-safe, but usually, succeed()/fail() are only called
* from a single thread (different from the thread(s) waiting for the result
* using the TFuture interface, though).
*/
class TPromise(ResultType) : TFuture!ResultType {
this() {
statusMutex_ = new Mutex;
completionEvent_ = new TOneshotEvent;
}
override S status() const @property {
return atomicLoad(status_);
}
override TAwaitable completion() @property {
return completionEvent_;
}
override ResultType get() {
auto s = atomicLoad(status_);
enforce(s != S.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == S.CANCELLED) throw new TCancelledException;
if (s == S.FAILED) throw exception_;
static if (!is(ResultType == void)) {
return result_;
}
}
override Exception getException() {
auto s = atomicLoad(status_);
enforce(s != S.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == S.CANCELLED) throw new TCancelledException;
if (s == S.SUCCEEDED) return null;
return exception_;
}
static if (!is(ResultType == void)) {
/**
* Sets the result of the operation, marks it as done, and notifies any
* waiters.
*
* If the operation has been cancelled before, nothing happens.
*
* Throws: TFutureException if the operation is already completed.
*/
void succeed(ResultType result) {
synchronized (statusMutex_) {
auto s = atomicLoad(status_);
if (s == S.CANCELLED) return;
enforce(s == S.RUNNING,
new TFutureException("Operation already completed."));
result_ = result;
atomicStore(status_, S.SUCCEEDED);
}
completionEvent_.trigger();
}
} else {
void succeed() {
synchronized (statusMutex_) {
auto s = atomicLoad(status_);
if (s == S.CANCELLED) return;
enforce(s == S.RUNNING,
new TFutureException("Operation already completed."));
atomicStore(status_, S.SUCCEEDED);
}
completionEvent_.trigger();
}
}
/**
* Marks the operation as failed with the specified exception and notifies
* any waiters.
*
* If the operation was already cancelled, nothing happens.
*
* Throws: TFutureException if the operation is already completed.
*/
void fail(Exception exception) {
synchronized (statusMutex_) {
auto status = atomicLoad(status_);
if (status == S.CANCELLED) return;
enforce(status == S.RUNNING,
new TFutureException("Operation already completed."));
exception_ = exception;
atomicStore(status_, S.FAILED);
}
completionEvent_.trigger();
}
/**
* Marks this operation as completed and takes over the outcome of another
* TFuture of the same type.
*
* If this operation was already cancelled, nothing happens. If the other
* operation was cancelled, this operation is marked as failed with a
* TCancelledException.
*
* Throws: TFutureException if the passed in future was not completed or
* this operation is already completed.
*/
void complete(TFuture!ResultType future) {
synchronized (statusMutex_) {
auto status = atomicLoad(status_);
if (status == S.CANCELLED) return;
enforce(status == S.RUNNING,
new TFutureException("Operation already completed."));
enforce(future.status != S.RUNNING, new TFutureException(
"The passed TFuture is not yet completed."));
status = future.status;
if (status == S.CANCELLED) {
status = S.FAILED;
exception_ = new TCancelledException;
} else if (status == S.FAILED) {
exception_ = future.getException();
} else static if (!is(ResultType == void)) {
result_ = future.get();
}
atomicStore(status_, status);
}
completionEvent_.trigger();
}
/**
* Marks this operation as cancelled and notifies any waiters.
*
* If the operation is already completed, nothing happens.
*/
void cancel() {
synchronized (statusMutex_) {
auto status = atomicLoad(status_);
if (status == S.RUNNING) atomicStore(status_, S.CANCELLED);
}
completionEvent_.trigger();
}
private:
// Convenience alias because TFutureStatus is ubiquitous in this class.
alias TFutureStatus S;
// The status the promise is currently in.
shared S status_;
union {
static if (!is(ResultType == void)) {
// Set if status_ is SUCCEEDED.
ResultType result_;
}
// Set if status_ is FAILED.
Exception exception_;
}
// Protects status_.
// As for result_ and exception_: They are only set once, while status_ is
// still RUNNING, so given that the operation has already completed, reading
// them is safe without holding some kind of lock.
Mutex statusMutex_;
// Triggered when the event completes.
TOneshotEvent completionEvent_;
}
///
class TFutureException : TException {
///
this(string msg = "", string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
}
}
/**
* Creates an interface that is similar to a given one, but accepts an
* additional, optional TCancellation parameter each method, and returns
* TFutures instead of plain return values.
*
* For example, given the following declarations:
* ---
* interface Foo {
* void bar();
* string baz(int a);
* }
* alias TFutureInterface!Foo FutureFoo;
* ---
*
* FutureFoo would be equivalent to:
* ---
* interface FutureFoo {
* TFuture!void bar(TCancellation cancellation = null);
* TFuture!string baz(int a, TCancellation cancellation = null);
* }
* ---
*/
template TFutureInterface(Interface) if (is(Interface _ == interface)) {
mixin({
string code = "interface TFutureInterface \n";
static if (is(Interface Bases == super) && Bases.length > 0) {
code ~= ": ";
foreach (i; 0 .. Bases.length) {
if (i > 0) code ~= ", ";
code ~= "TFutureInterface!(BaseTypeTuple!Interface[" ~ to!string(i) ~ "]) ";
}
}
code ~= "{\n";
foreach (methodName; __traits(derivedMembers, Interface)) {
enum qn = "Interface." ~ methodName;
static if (isSomeFunction!(mixin(qn))) {
code ~= "TFuture!(ReturnType!(" ~ qn ~ ")) " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ "), TCancellation cancellation = null);\n";
}
}
code ~= "}\n";
return code;
}());
}
/**
* An input range that aggregates results from multiple asynchronous operations,
* returning them in the order they arrive.
*
* Additionally, a timeout can be set after which results from not yet finished
* futures will no longer be waited for, e.g. to ensure the time it takes to
* iterate over a set of results is limited.
*/
final class TFutureAggregatorRange(T) {
/**
* Constructs a new instance.
*
* Params:
* futures = The set of futures to collect results from.
* timeout = If positive, not yet finished futures will be cancelled and
* their results will not be taken into account.
*/
this(TFuture!T[] futures, TCancellationOrigin childCancellation,
Duration timeout = dur!"hnsecs"(0)
) {
if (timeout > dur!"hnsecs"(0)) {
timeoutSysTick_ = TickDuration.currSystemTick +
TickDuration.from!"hnsecs"(timeout.total!"hnsecs");
} else {
timeoutSysTick_ = TickDuration(0);
}
queueMutex_ = new Mutex;
queueNonEmptyCondition_ = new Condition(queueMutex_);
futures_ = futures;
childCancellation_ = childCancellation;
foreach (future; futures_) {
future.completion.addCallback({
auto f = future;
return {
if (f.status == TFutureStatus.CANCELLED) return;
assert(f.status != TFutureStatus.RUNNING);
synchronized (queueMutex_) {
completedQueue_ ~= f;
if (completedQueue_.length == 1) {
queueNonEmptyCondition_.notifyAll();
}
}
};
}());
}
}
/**
* Whether the range is empty.
*
* This is the case if the results from the completed futures not having
* failed have already been popped and either all future have been finished
* or the timeout has expired.
*
* Potentially blocks until a new result is available or the timeout has
* expired.
*/
bool empty() @property {
if (finished_) return true;
if (bufferFilled_) return false;
while (true) {
TFuture!T future;
synchronized (queueMutex_) {
// The while loop is just being cautious about spurious wakeups, in
// case they should be possible.
while (completedQueue_.empty) {
auto remaining = to!Duration(timeoutSysTick_ -
TickDuration.currSystemTick);
if (remaining <= dur!"hnsecs"(0)) {
// No time left, but still no element received we are empty now.
finished_ = true;
childCancellation_.trigger();
return true;
}
queueNonEmptyCondition_.wait(remaining);
}
future = completedQueue_.front;
completedQueue_.popFront();
}
++completedCount_;
if (completedCount_ == futures_.length) {
// This was the last future in the list, there is no possibility
// another result could ever become available.
finished_ = true;
}
if (future.status == TFutureStatus.FAILED) {
// This one failed, loop again and try getting another item from
// the queue.
exceptions_ ~= future.getException();
} else {
resultBuffer_ = future.get();
bufferFilled_ = true;
return false;
}
}
}
/**
* Returns the first element from the range.
*
* Potentially blocks until a new result is available or the timeout has
* expired.
*
* Throws: TException if the range is empty.
*/
T front() {
enforce(!empty, new TException(
"Cannot get front of an empty future aggregator range."));
return resultBuffer_;
}
/**
* Removes the first element from the range.
*
* Potentially blocks until a new result is available or the timeout has
* expired.
*
* Throws: TException if the range is empty.
*/
void popFront() {
enforce(!empty, new TException(
"Cannot pop front of an empty future aggregator range."));
bufferFilled_ = false;
}
/**
* The number of futures the result of which has been returned or which have
* failed so far.
*/
size_t completedCount() @property const {
return completedCount_;
}
/**
* The exceptions collected from failed TFutures so far.
*/
Exception[] exceptions() @property {
return exceptions_;
}
private:
TFuture!T[] futures_;
TCancellationOrigin childCancellation_;
// The system tick this operation will time out, or zero if no timeout has
// been set.
TickDuration timeoutSysTick_;
bool finished_;
bool bufferFilled_;
T resultBuffer_;
Exception[] exceptions_;
size_t completedCount_;
// The queue of completed futures. This (and the associated condition) are
// the only parts of this class that are accessed by multiple threads.
TFuture!T[] completedQueue_;
Mutex queueMutex_;
Condition queueNonEmptyCondition_;
}
/**
* TFutureAggregatorRange construction helper to avoid having to explicitly
* specify the value type, i.e. to allow the constructor being called using IFTI
* (see $(DMDBUG 6082, D Bugzilla enhancement requet 6082)).
*/
TFutureAggregatorRange!T tFutureAggregatorRange(T)(TFuture!T[] futures,
TCancellationOrigin childCancellation, Duration timeout = dur!"hnsecs"(0)
) {
return new TFutureAggregatorRange!T(futures, childCancellation, timeout);
}

View file

@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.hashset;
import std.algorithm : joiner, map;
import std.conv : to;
import std.traits : isImplicitlyConvertible, ParameterTypeTuple;
import std.range : ElementType, isInputRange;
struct Void {}
/**
* A quickly hacked together hash set implementation backed by built-in
* associative arrays to have something to compile Thrift's set<> to until
* std.container gains something suitable.
*/
// Note: The funky pointer casts (i.e. *(cast(immutable(E)*)&e) instead of
// just cast(immutable(E))e) are a workaround for LDC 2 compatibility.
final class HashSet(E) {
///
this() {}
///
this(E[] elems...) {
insert(elems);
}
///
void insert(Stuff)(Stuff stuff) if (isImplicitlyConvertible!(Stuff, E)) {
aa_[*(cast(immutable(E)*)&stuff)] = Void.init;
}
///
void insert(Stuff)(Stuff stuff) if (
isInputRange!Stuff && isImplicitlyConvertible!(ElementType!Stuff, E)
) {
foreach (e; stuff) {
aa_[*(cast(immutable(E)*)&e)] = Void.init;
}
}
///
void opOpAssign(string op : "~", Stuff)(Stuff stuff) {
insert(stuff);
}
///
void remove(E e) {
aa_.remove(*(cast(immutable(E)*)&e));
}
alias remove removeKey;
///
void removeAll() {
aa_ = null;
}
///
size_t length() @property const {
return aa_.length;
}
///
size_t empty() @property const {
return !aa_.length;
}
///
bool opBinaryRight(string op : "in")(E e) const {
return (e in aa_) !is null;
}
///
auto opSlice() const {
// TODO: Implement using AA key range once available in release DMD/druntime
// to avoid allocation.
return cast(E[])(aa_.keys);
}
///
override string toString() const {
// Only provide toString() if to!string() is available for E (exceptions are
// e.g. delegates).
static if (is(typeof(to!string(E.init)) : string)) {
return "{" ~ to!string(joiner(map!`to!string(a)`(aa_.keys), ", ")) ~ "}";
} else {
// Cast to work around Object not being const-correct.
return (cast()super).toString();
}
}
///
override bool opEquals(Object other) const {
auto rhs = cast(const(HashSet))other;
if (rhs) {
return aa_ == rhs.aa_;
}
// Cast to work around Object not being const-correct.
return (cast()super).opEquals(other);
}
private:
Void[immutable(E)] aa_;
}
/// Ditto
auto hashSet(E)(E[] elems...) {
return new HashSet!E(elems);
}
unittest {
import std.exception;
auto a = hashSet(1, 2, 2, 3);
enforce(a.length == 3);
enforce(2 in a);
enforce(5 !in a);
enforce(a.toString().length == 9);
a.remove(2);
enforce(a.length == 2);
enforce(2 !in a);
a.removeAll();
enforce(a.empty);
enforce(a.toString() == "{}");
void delegate() dg;
auto b = hashSet(dg);
static assert(__traits(compiles, b.toString()));
}

129
vendor/git.apache.org/thrift.git/lib/d/test/Makefile.am generated vendored Executable file
View file

@ -0,0 +1,129 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
AUTOMAKE_OPTIONS = serial-tests
BUILT_SOURCES = trusted-ca-certificate.pem server-certificate.pem
# Thrift compiler rules
THRIFT = $(top_builddir)/compiler/cpp/thrift
debug_proto_gen = $(addprefix gen-d/, DebugProtoTest_types.d)
$(debug_proto_gen): $(top_srcdir)/test/DebugProtoTest.thrift
$(THRIFT) --gen d -nowarn $<
stress_test_gen = $(addprefix gen-d/thrift/test/stress/, Service.d \
StressTest_types.d)
$(stress_test_gen): $(top_srcdir)/test/StressTest.thrift
$(THRIFT) --gen d $<
thrift_test_gen = $(addprefix gen-d/thrift/test/, SecondService.d \
ThriftTest.d ThriftTest_constants.d ThriftTest_types.d)
$(thrift_test_gen): $(top_srcdir)/test/ThriftTest.thrift
$(THRIFT) --gen d $<
# The actual test targets.
# There just must be some way to reassign a variable without warnings in
# Automake...
targets__ = async_test client_pool_test serialization_benchmark \
stress_test_server thrift_test_client thrift_test_server transport_test
ran_tests__ = client_pool_test \
transport_test \
async_test_runner.sh \
thrift_test_runner.sh
libevent_dependent_targets = async_test_client client_pool_test \
stress_test_server thrift_test_server
libevent_dependent_ran_tests = client_pool_test async_test_runner.sh thrift_test_runner.sh
openssl_dependent_targets = async_test thrift_test_client thrift_test_server
openssl_dependent_ran_tests = async_test_runner.sh thrift_test_runner.sh
d_test_flags =
if WITH_D_EVENT_TESTS
d_test_flags += $(DMD_LIBEVENT_FLAGS) ../$(D_EVENT_LIB_NAME)
targets_ = $(targets__)
ran_tests_ = $(ran_tests__)
else
targets_ = $(filter-out $(libevent_dependent_targets), $(targets__))
ran_tests_ = $(filter-out $(libevent_dependent_ran_tests), $(ran_tests__))
endif
if WITH_D_SSL_TESTS
d_test_flags += $(DMD_OPENSSL_FLAGS) ../$(D_SSL_LIB_NAME)
targets = $(targets_)
ran_tests = $(ran_tests_)
else
targets = $(filter-out $(openssl_dependent_targets), $(targets_))
ran_tests = $(filter-out $(openssl_dependent_ran_tests), $(ran_tests_))
endif
d_test_flags += -w -wi -O -release -inline -I$(top_srcdir)/lib/d/src -Igen-d \
$(top_builddir)/lib/d/$(D_LIB_NAME)
async_test client_pool_test transport_test: %: %.d
$(DMD) $(d_test_flags) -of$@ $^
serialization_benchmark: %: %.d $(debug_proto_gen)
$(DMD) $(d_test_flags) -of$@ $^
stress_test_server: %: %.d test_utils.d $(stress_test_gen)
$(DMD) $(d_test_flags) -of$@ $^
thrift_test_client: %: %.d thrift_test_common.d $(thrift_test_gen)
$(DMD) $(d_test_flags) -of$@ $^
thrift_test_server: %: %.d thrift_test_common.d test_utils.d $(thrift_test_gen)
$(DMD) $(d_test_flags) -of$@ $^
# Certificate generation targets (for the SSL tests).
# Currently, we just assume that the "openssl" tool is on the path, could be
# replaced by a more elaborate mechanism.
server-certificate.pem: openssl.test.cnf
openssl req -new -x509 -nodes -config openssl.test.cnf \
-out server-certificate.pem
trusted-ca-certificate.pem: server-certificate.pem
cat server-certificate.pem > $@
check-local: $(targets)
clean-local:
$(RM) -rf gen-d $(targets) $(addsuffix .o, $(targets))
# Tests ran as part of make check.
async_test_runner.sh: async_test trusted-ca-certificate.pem server-certificate.pem
thrift_test_runner.sh: thrift_test_client thrift_test_server \
trusted-ca-certificate.pem server-certificate.pem
TESTS = $(ran_tests)
precross: $(targets)

View file

@ -0,0 +1,396 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless enforced by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module async_test;
import core.atomic;
import core.sync.condition : Condition;
import core.sync.mutex : Mutex;
import core.thread : dur, Thread, ThreadGroup;
import std.conv : text;
import std.datetime;
import std.getopt;
import std.exception : collectException, enforce;
import std.parallelism : TaskPool;
import std.stdio;
import std.string;
import std.variant : Variant;
import thrift.base;
import thrift.async.base;
import thrift.async.libevent;
import thrift.async.socket;
import thrift.async.ssl;
import thrift.codegen.async_client;
import thrift.codegen.async_client_pool;
import thrift.codegen.base;
import thrift.codegen.processor;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.server.base;
import thrift.server.simple;
import thrift.server.transport.socket;
import thrift.server.transport.ssl;
import thrift.transport.base;
import thrift.transport.buffered;
import thrift.transport.ssl;
import thrift.util.cancellation;
version (Posix) {
import core.stdc.signal;
import core.sys.posix.signal;
// Disable SIGPIPE because SSL server will write to broken socket after
// client disconnected (see TSSLSocket docs).
shared static this() {
signal(SIGPIPE, SIG_IGN);
}
}
interface AsyncTest {
string echo(string value);
string delayedEcho(string value, long milliseconds);
void fail(string reason);
void delayedFail(string reason, long milliseconds);
enum methodMeta = [
TMethodMeta("fail", [], [TExceptionMeta("ate", 1, "AsyncTestException")]),
TMethodMeta("delayedFail", [], [TExceptionMeta("ate", 1, "AsyncTestException")])
];
alias .AsyncTestException AsyncTestException;
}
class AsyncTestException : TException {
string reason;
mixin TStructHelpers!();
}
void main(string[] args) {
ushort port = 9090;
ushort managerCount = 2;
ushort serversPerManager = 5;
ushort threadsPerServer = 10;
uint iterations = 10;
bool ssl;
bool trace;
getopt(args,
"iterations", &iterations,
"managers", &managerCount,
"port", &port,
"servers-per-manager", &serversPerManager,
"ssl", &ssl,
"threads-per-server", &threadsPerServer,
"trace", &trace,
);
TTransportFactory clientTransportFactory;
TSSLContext serverSSLContext;
if (ssl) {
auto clientSSLContext = new TSSLContext();
with (clientSSLContext) {
ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
authenticate = true;
loadTrustedCertificates("./trusted-ca-certificate.pem");
}
clientTransportFactory = new TAsyncSSLSocketFactory(clientSSLContext);
serverSSLContext = new TSSLContext();
with (serverSSLContext) {
serverSide = true;
loadCertificate("./server-certificate.pem");
loadPrivateKey("./server-private-key.pem");
ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
}
} else {
clientTransportFactory = new TBufferedTransportFactory;
}
auto serverCancel = new TCancellationOrigin;
scope(exit) {
writeln("Triggering server shutdown...");
serverCancel.trigger();
writeln("done.");
}
auto managers = new TLibeventAsyncManager[managerCount];
scope (exit) foreach (ref m; managers) destroy(m);
auto clientsThreads = new ThreadGroup;
foreach (managerIndex, ref manager; managers) {
manager = new TLibeventAsyncManager;
foreach (serverIndex; 0 .. serversPerManager) {
auto currentPort = cast(ushort)
(port + managerIndex * serversPerManager + serverIndex);
// Start the server and wait until it is up and running.
auto servingMutex = new Mutex;
auto servingCondition = new Condition(servingMutex);
auto handler = new PreServeNotifyHandler(servingMutex, servingCondition);
synchronized (servingMutex) {
(new ServerThread!TSimpleServer(currentPort, serverSSLContext, trace,
serverCancel, handler)).start();
servingCondition.wait();
}
// We only run the timing tests for the first server on each async
// manager, so that we don't get spurious timing errors becaue of
// ordering issues.
auto runTimingTests = (serverIndex == 0);
auto c = new ClientsThread(manager, currentPort, clientTransportFactory,
threadsPerServer, iterations, runTimingTests, trace);
clientsThreads.add(c);
c.start();
}
}
clientsThreads.joinAll();
}
class AsyncTestHandler : AsyncTest {
this(bool trace) {
trace_ = trace;
}
override string echo(string value) {
if (trace_) writefln(`echo("%s")`, value);
return value;
}
override string delayedEcho(string value, long milliseconds) {
if (trace_) writef(`delayedEcho("%s", %s ms)... `, value, milliseconds);
Thread.sleep(dur!"msecs"(milliseconds));
if (trace_) writeln("returning.");
return value;
}
override void fail(string reason) {
if (trace_) writefln(`fail("%s")`, reason);
auto ate = new AsyncTestException;
ate.reason = reason;
throw ate;
}
override void delayedFail(string reason, long milliseconds) {
if (trace_) writef(`delayedFail("%s", %s ms)... `, reason, milliseconds);
Thread.sleep(dur!"msecs"(milliseconds));
if (trace_) writeln("returning.");
auto ate = new AsyncTestException;
ate.reason = reason;
throw ate;
}
private:
bool trace_;
AsyncTestException ate_;
}
class PreServeNotifyHandler : TServerEventHandler {
this(Mutex servingMutex, Condition servingCondition) {
servingMutex_ = servingMutex;
servingCondition_ = servingCondition;
}
void preServe() {
synchronized (servingMutex_) {
servingCondition_.notifyAll();
}
}
Variant createContext(TProtocol input, TProtocol output) { return Variant.init; }
void deleteContext(Variant serverContext, TProtocol input, TProtocol output) {}
void preProcess(Variant serverContext, TTransport transport) {}
private:
Mutex servingMutex_;
Condition servingCondition_;
}
class ServerThread(ServerType) : Thread {
this(ushort port, TSSLContext sslContext, bool trace,
TCancellation cancellation, TServerEventHandler eventHandler
) {
port_ = port;
sslContext_ = sslContext;
trace_ = trace;
cancellation_ = cancellation;
eventHandler_ = eventHandler;
super(&run);
}
void run() {
TServerSocket serverSocket;
if (sslContext_) {
serverSocket = new TSSLServerSocket(port_, sslContext_);
} else {
serverSocket = new TServerSocket(port_);
}
auto transportFactory = new TBufferedTransportFactory;
auto protocolFactory = new TBinaryProtocolFactory!();
auto processor = new TServiceProcessor!AsyncTest(new AsyncTestHandler(trace_));
auto server = new ServerType(processor, serverSocket, transportFactory,
protocolFactory);
server.eventHandler = eventHandler_;
writefln("Starting server on port %s...", port_);
server.serve(cancellation_);
writefln("Server thread on port %s done.", port_);
}
private:
ushort port_;
bool trace_;
TCancellation cancellation_;
TSSLContext sslContext_;
TServerEventHandler eventHandler_;
}
class ClientsThread : Thread {
this(TAsyncSocketManager manager, ushort port, TTransportFactory tf,
ushort threads, uint iterations, bool runTimingTests, bool trace
) {
manager_ = manager;
port_ = port;
transportFactory_ = tf;
threads_ = threads;
iterations_ = iterations;
runTimingTests_ = runTimingTests;
trace_ = trace;
super(&run);
}
void run() {
auto transport = new TAsyncSocket(manager_, "localhost", port_);
{
auto client = new TAsyncClient!AsyncTest(
transport,
transportFactory_,
new TBinaryProtocolFactory!()
);
transport.open();
auto clientThreads = new ThreadGroup;
foreach (clientId; 0 .. threads_) {
clientThreads.create({
auto c = clientId;
return {
foreach (i; 0 .. iterations_) {
immutable id = text(port_, ":", c, ":", i);
{
if (trace_) writefln(`Calling echo("%s")... `, id);
auto a = client.echo(id);
enforce(a == id);
if (trace_) writefln(`echo("%s") done.`, id);
}
{
if (trace_) writefln(`Calling fail("%s")... `, id);
auto a = cast(AsyncTestException)collectException(client.fail(id).waitGet());
enforce(a && a.reason == id);
if (trace_) writefln(`fail("%s") done.`, id);
}
}
};
}());
}
clientThreads.joinAll();
transport.close();
}
if (runTimingTests_) {
auto client = new TAsyncClient!AsyncTest(
transport,
transportFactory_,
new TBinaryProtocolFactory!TBufferedTransport
);
// Temporarily redirect error logs to stdout, as SSL errors on the server
// side are expected when the client terminates aburptly (as is the case
// in the timeout test).
auto oldErrorLogSink = g_errorLogSink;
g_errorLogSink = g_infoLogSink;
scope (exit) g_errorLogSink = oldErrorLogSink;
foreach (i; 0 .. iterations_) {
transport.open();
immutable id = text(port_, ":", i);
{
if (trace_) writefln(`Calling delayedEcho("%s", 100 ms)...`, id);
auto a = client.delayedEcho(id, 100);
enforce(!a.completion.wait(dur!"usecs"(1)),
text("wait() succeeded early (", a.get(), ", ", id, ")."));
enforce(!a.completion.wait(dur!"usecs"(1)),
text("wait() succeeded early (", a.get(), ", ", id, ")."));
enforce(a.completion.wait(dur!"msecs"(200)),
text("wait() didn't succeed as expected (", id, ")."));
enforce(a.get() == id);
if (trace_) writefln(`... delayedEcho("%s") done.`, id);
}
{
if (trace_) writefln(`Calling delayedFail("%s", 100 ms)... `, id);
auto a = client.delayedFail(id, 100);
enforce(!a.completion.wait(dur!"usecs"(1)),
text("wait() succeeded early (", id, ", ", collectException(a.get()), ")."));
enforce(!a.completion.wait(dur!"usecs"(1)),
text("wait() succeeded early (", id, ", ", collectException(a.get()), ")."));
enforce(a.completion.wait(dur!"msecs"(200)),
text("wait() didn't succeed as expected (", id, ")."));
auto e = cast(AsyncTestException)collectException(a.get());
enforce(e && e.reason == id);
if (trace_) writefln(`... delayedFail("%s") done.`, id);
}
{
transport.recvTimeout = dur!"msecs"(50);
if (trace_) write(`Calling delayedEcho("socketTimeout", 100 ms)... `);
auto a = client.delayedEcho("socketTimeout", 100);
auto e = cast(TTransportException)collectException(a.waitGet());
enforce(e, text("Operation didn't fail as expected (", id, ")."));
enforce(e.type == TTransportException.Type.TIMED_OUT,
text("Wrong timeout exception type (", id, "): ", e));
if (trace_) writeln(`timed out as expected.`);
// Wait until the server thread reset before the next iteration.
Thread.sleep(dur!"msecs"(50));
transport.recvTimeout = dur!"hnsecs"(0);
}
transport.close();
}
}
writefln("Clients thread for port %s done.", port_);
}
TAsyncSocketManager manager_;
ushort port_;
TTransportFactory transportFactory_;
ushort threads_;
uint iterations_;
bool runTimingTests_;
bool trace_;
}

View file

@ -0,0 +1,28 @@
#!/bin/bash
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
CUR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Runs the async test in both SSL and non-SSL mode.
${CUR}/async_test > /dev/null || exit 1
echo "Non-SSL tests done."
${CUR}/async_test --ssl > /dev/null || exit 1
echo "SSL tests done."

View file

@ -0,0 +1,416 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module client_pool_test;
import core.time : Duration, dur;
import core.thread : Thread;
import std.algorithm;
import std.array;
import std.conv;
import std.exception;
import std.getopt;
import std.range;
import std.stdio;
import std.typecons;
import thrift.base;
import thrift.async.libevent;
import thrift.async.socket;
import thrift.codegen.base;
import thrift.codegen.async_client;
import thrift.codegen.async_client_pool;
import thrift.codegen.client;
import thrift.codegen.client_pool;
import thrift.codegen.processor;
import thrift.protocol.binary;
import thrift.server.simple;
import thrift.server.transport.socket;
import thrift.transport.buffered;
import thrift.transport.socket;
import thrift.util.cancellation;
import thrift.util.future;
// We use this as our RPC-layer exception here to make sure socket/… problems
// (that would usually considered to be RPC layer faults) cause the tests to
// fail, even though we are testing the RPC exception handling.
class TestServiceException : TException {
int port;
}
interface TestService {
int getPort();
alias .TestServiceException TestServiceException;
enum methodMeta = [TMethodMeta("getPort", [],
[TExceptionMeta("a", 1, "TestServiceException")])];
}
// Use some derived service, just to check that the pools handle inheritance
// correctly.
interface ExTestService : TestService {
int[] getPortInArray();
enum methodMeta = [TMethodMeta("getPortInArray", [],
[TExceptionMeta("a", 1, "TestServiceException")])];
}
class ExTestHandler : ExTestService {
this(ushort port, Duration delay, bool failing, bool trace) {
this.port = port;
this.delay = delay;
this.failing = failing;
this.trace = trace;
}
override int getPort() {
if (trace) {
stderr.writefln("getPort() called on %s (delay: %s, failing: %s)", port,
delay, failing);
}
sleep();
failIfEnabled();
return port;
}
override int[] getPortInArray() {
return [getPort()];
}
ushort port;
Duration delay;
bool failing;
bool trace;
private:
void sleep() {
if (delay > dur!"hnsecs"(0)) Thread.sleep(delay);
}
void failIfEnabled() {
if (!failing) return;
auto e = new TestServiceException;
e.port = port;
throw e;
}
}
class ServerThread : Thread {
this(ExTestHandler handler, TCancellation cancellation) {
super(&run);
handler_ = handler;
cancellation_ = cancellation;
}
private:
void run() {
try {
auto protocolFactory = new TBinaryProtocolFactory!();
auto processor = new TServiceProcessor!ExTestService(handler_);
auto serverTransport = new TServerSocket(handler_.port);
serverTransport.recvTimeout = dur!"seconds"(3);
auto transportFactory = new TBufferedTransportFactory;
auto server = new TSimpleServer(
processor, serverTransport, transportFactory, protocolFactory);
server.serve(cancellation_);
} catch (Exception e) {
writefln("Server thread on port %s failed: %s", handler_.port, e);
}
}
TCancellation cancellation_;
ExTestHandler handler_;
}
void main(string[] args) {
bool trace;
ushort port = 9090;
getopt(args, "port", &port, "trace", &trace);
auto serverCancellation = new TCancellationOrigin;
scope (exit) serverCancellation.trigger();
immutable ports = cast(immutable)array(map!"cast(ushort)a"(iota(port, port + 6)));
version (none) {
// Cannot use this due to multiple DMD @@BUG@@s:
// 1. »function D main is a nested function and cannot be accessed from array«
// when calling array() on the result of the outer map() would have to
// manually do the eager evaluation/array conversion.
// 2. »Zip.opSlice cannot get frame pointer to map« for the delay argument,
// can be worked around by calling array() on the map result first.
// 3. Even when using the workarounds for the last two points, the DMD-built
// executable crashes when building without (sic!) inlining enabled,
// the backtrace points into the first delegate literal.
auto handlers = array(map!((args){
return new ExTestHandler(args._0, args._1, args._2, trace);
})(zip(
ports,
map!((a){ return dur!`msecs`(a); })([1, 10, 100, 1, 10, 100]),
[false, false, false, true, true, true]
)));
} else {
auto handlers = [
new ExTestHandler(cast(ushort)(port + 0), dur!"msecs"(1), false, trace),
new ExTestHandler(cast(ushort)(port + 1), dur!"msecs"(10), false, trace),
new ExTestHandler(cast(ushort)(port + 2), dur!"msecs"(100), false, trace),
new ExTestHandler(cast(ushort)(port + 3), dur!"msecs"(1), true, trace),
new ExTestHandler(cast(ushort)(port + 4), dur!"msecs"(10), true, trace),
new ExTestHandler(cast(ushort)(port + 5), dur!"msecs"(100), true, trace)
];
}
// Fire up the server threads.
foreach (h; handlers) (new ServerThread(h, serverCancellation)).start();
// Give the servers some time to get up. This should really be accomplished
// via a barrier here and in the preServe() hook.
Thread.sleep(dur!"msecs"(10));
syncClientPoolTest(ports, handlers);
asyncClientPoolTest(ports, handlers);
asyncFastestClientPoolTest(ports, handlers);
asyncAggregatorTest(ports, handlers);
}
void syncClientPoolTest(const(ushort)[] ports, ExTestHandler[] handlers) {
auto clients = array(map!((a){
return cast(TClientBase!ExTestService)tClient!ExTestService(
tBinaryProtocol(new TSocket("127.0.0.1", a))
);
})(ports));
scope(exit) foreach (c; clients) c.outputProtocol.transport.close();
// Try the case where the first client succeeds.
{
enforce(makePool(clients).getPort() == ports[0]);
}
// Try the case where all clients fail.
{
auto pool = makePool(clients[3 .. $]);
auto e = cast(TCompoundOperationException)collectException(pool.getPort());
enforce(e);
enforce(equal(map!"a.port"(cast(TestServiceException[])e.exceptions),
ports[3 .. $]));
}
// Try the case where the first clients fail, but a later one succeeds.
{
auto pool = makePool(clients[3 .. $] ~ clients[0 .. 3]);
enforce(pool.getPortInArray() == [ports[0]]);
}
// Make sure a client is properly deactivated when it has failed too often.
{
auto pool = makePool(clients);
pool.faultDisableCount = 1;
pool.faultDisableDuration = dur!"msecs"(50);
handlers[0].failing = true;
enforce(pool.getPort() == ports[1]);
handlers[0].failing = false;
enforce(pool.getPort() == ports[1]);
Thread.sleep(dur!"msecs"(50));
enforce(pool.getPort() == ports[0]);
}
}
auto makePool(TClientBase!ExTestService[] clients) {
auto p = tClientPool(clients);
p.permuteClients = false;
p.rpcFaultFilter = (Exception e) {
return (cast(TestServiceException)e !is null);
};
return p;
}
void asyncClientPoolTest(const(ushort)[] ports, ExTestHandler[] handlers) {
auto manager = new TLibeventAsyncManager;
scope (exit) manager.stop(dur!"hnsecs"(0));
auto clients = makeAsyncClients(manager, ports);
scope(exit) foreach (c; clients) c.transport.close();
// Try the case where the first client succeeds.
{
enforce(makeAsyncPool(clients).getPort() == ports[0]);
}
// Try the case where all clients fail.
{
auto pool = makeAsyncPool(clients[3 .. $]);
auto e = cast(TCompoundOperationException)collectException(pool.getPort().waitGet());
enforce(e);
enforce(equal(map!"a.port"(cast(TestServiceException[])e.exceptions),
ports[3 .. $]));
}
// Try the case where the first clients fail, but a later one succeeds.
{
auto pool = makeAsyncPool(clients[3 .. $] ~ clients[0 .. 3]);
enforce(pool.getPortInArray() == [ports[0]]);
}
// Make sure a client is properly deactivated when it has failed too often.
{
auto pool = makeAsyncPool(clients);
pool.faultDisableCount = 1;
pool.faultDisableDuration = dur!"msecs"(50);
handlers[0].failing = true;
enforce(pool.getPort() == ports[1]);
handlers[0].failing = false;
enforce(pool.getPort() == ports[1]);
Thread.sleep(dur!"msecs"(50));
enforce(pool.getPort() == ports[0]);
}
}
auto makeAsyncPool(TAsyncClientBase!ExTestService[] clients) {
auto p = tAsyncClientPool(clients);
p.permuteClients = false;
p.rpcFaultFilter = (Exception e) {
return (cast(TestServiceException)e !is null);
};
return p;
}
auto makeAsyncClients(TLibeventAsyncManager manager, in ushort[] ports) {
// DMD @@BUG@@ workaround: Using array on the lazyHandlers map result leads
// to »function D main is a nested function and cannot be accessed from array«.
// Thus, we manually do the array conversion.
auto lazyClients = map!((a){
return new TAsyncClient!ExTestService(
new TAsyncSocket(manager, "127.0.0.1", a),
new TBufferedTransportFactory,
new TBinaryProtocolFactory!(TBufferedTransport)
);
})(ports);
TAsyncClientBase!ExTestService[] clients;
foreach (c; lazyClients) clients ~= c;
return clients;
}
void asyncFastestClientPoolTest(const(ushort)[] ports, ExTestHandler[] handlers) {
auto manager = new TLibeventAsyncManager;
scope (exit) manager.stop(dur!"hnsecs"(0));
auto clients = makeAsyncClients(manager, ports);
scope(exit) foreach (c; clients) c.transport.close();
// Make sure the fastest client wins, even if they are called in some other
// order.
{
auto result = makeAsyncFastestPool(array(retro(clients))).getPort().waitGet();
enforce(result == ports[0]);
}
// Try the case where all clients fail.
{
auto pool = makeAsyncFastestPool(clients[3 .. $]);
auto e = cast(TCompoundOperationException)collectException(pool.getPort().waitGet());
enforce(e);
enforce(equal(map!"a.port"(cast(TestServiceException[])e.exceptions),
ports[3 .. $]));
}
// Try the case where the first clients fail, but a later one succeeds.
{
auto pool = makeAsyncFastestPool(clients[1 .. $]);
enforce(pool.getPortInArray() == [ports[1]]);
}
}
auto makeAsyncFastestPool(TAsyncClientBase!ExTestService[] clients) {
auto p = tAsyncFastestClientPool(clients);
p.rpcFaultFilter = (Exception e) {
return (cast(TestServiceException)e !is null);
};
return p;
}
void asyncAggregatorTest(const(ushort)[] ports, ExTestHandler[] handlers) {
auto manager = new TLibeventAsyncManager;
scope (exit) manager.stop(dur!"hnsecs"(0));
auto clients = makeAsyncClients(manager, ports);
scope(exit) foreach (c; clients) c.transport.close();
auto aggregator = tAsyncAggregator(
cast(TAsyncClientBase!ExTestService[])clients);
// Test aggregator range interface.
{
auto range = aggregator.getPort().range(dur!"msecs"(50));
enforce(equal(range, ports[0 .. 2][]));
enforce(equal(map!"a.port"(cast(TestServiceException[])range.exceptions),
ports[3 .. $ - 1]));
enforce(range.completedCount == 4);
}
// Test default accumulator for scalars.
{
auto fullResult = aggregator.getPort().accumulate();
enforce(fullResult.waitGet() == ports[0 .. 3]);
auto partialResult = aggregator.getPort().accumulate();
Thread.sleep(dur!"msecs"(20));
enforce(partialResult.finishGet() == ports[0 .. 2]);
}
// Test default accumulator for arrays.
{
auto fullResult = aggregator.getPortInArray().accumulate();
enforce(fullResult.waitGet() == ports[0 .. 3]);
auto partialResult = aggregator.getPortInArray().accumulate();
Thread.sleep(dur!"msecs"(20));
enforce(partialResult.finishGet() == ports[0 .. 2]);
}
// Test custom accumulator.
{
auto fullResult = aggregator.getPort().accumulate!(function(int[] results){
return reduce!"a + b"(results);
})();
enforce(fullResult.waitGet() == ports[0] + ports[1] + ports[2]);
auto partialResult = aggregator.getPort().accumulate!(
function(int[] results, Exception[] exceptions) {
// Return a tuple of the parameters so we can check them outside of
// this function (to verify the values, we need access to »ports«, but
// due to DMD @@BUG5710@@, we can't use a delegate literal).f
return tuple(results, exceptions);
}
)();
Thread.sleep(dur!"msecs"(20));
auto resultTuple = partialResult.finishGet();
enforce(resultTuple._0 == ports[0 .. 2]);
enforce(equal(map!"a.port"(cast(TestServiceException[])resultTuple._1),
ports[3 .. $ - 1]));
}
}

View file

@ -0,0 +1,14 @@
[ req ]
default_bits = 2048
default_keyfile = server-private-key.pem
distinguished_name = req_distinguished_name
x509_extensions = v3_ca
prompt = no
[ req_distinguished_name ]
CN = localhost
[ v3_ca ]
# Add ::1 to the list of allowed IPs so we can use ::1 to explicitly connect
# to localhost via IPv6.
subjectAltName = IP:::1

View file

@ -0,0 +1,70 @@
/**
* An implementation of the mini serialization benchmark also available for
* C++ and Java.
*
* For meaningful results, you might want to make sure that
* the Thrift library is compiled with release build flags,
* e.g. by including the source files with the build instead
* of linking libthriftd:
*
dmd -w -O -release -inline -I../src -Igen-d -ofserialization_benchmark \
$(find ../src/thrift -name '*.d' -not -name index.d) \
gen-d/DebugProtoTest_types.d serialization_benchmark.d
*/
module serialization_benchmark;
import std.datetime : AutoStart, StopWatch;
import std.math : PI;
import std.stdio;
import thrift.protocol.binary;
import thrift.transport.memory;
import thrift.transport.range;
import DebugProtoTest_types;
void main() {
auto buf = new TMemoryBuffer;
enum ITERATIONS = 10_000_000;
{
auto ooe = OneOfEach();
ooe.im_true = true;
ooe.im_false = false;
ooe.a_bite = 0x7f;
ooe.integer16 = 27_000;
ooe.integer32 = 1 << 24;
ooe.integer64 = 6_000_000_000;
ooe.double_precision = PI;
ooe.some_characters = "JSON THIS! \"\1";
ooe.zomg_unicode = "\xd7\n\a\t";
ooe.base64 = "\1\2\3\255";
auto prot = tBinaryProtocol(buf);
auto sw = StopWatch(AutoStart.yes);
foreach (i; 0 .. ITERATIONS) {
buf.reset(120);
ooe.write(prot);
}
sw.stop();
auto msecs = sw.peek().msecs;
writefln("Write: %s ms (%s kHz)", msecs, ITERATIONS / msecs);
}
auto data = buf.getContents().dup;
{
auto readBuf = tInputRangeTransport(data);
auto prot = tBinaryProtocol(readBuf);
auto ooe = OneOfEach();
auto sw = StopWatch(AutoStart.yes);
foreach (i; 0 .. ITERATIONS) {
readBuf.reset(data);
ooe.read(prot);
}
sw.stop();
auto msecs = sw.peek().msecs;
writefln(" Read: %s ms (%s kHz)", msecs, ITERATIONS / msecs);
}
}

View file

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module stress_test_server;
import std.getopt;
import std.parallelism : totalCPUs;
import std.stdio;
import std.typetuple;
import thrift.codegen.processor;
import thrift.protocol.binary;
import thrift.server.base;
import thrift.server.transport.socket;
import thrift.transport.buffered;
import thrift.transport.memory;
import thrift.transport.socket;
import thrift.util.hashset;
import test_utils;
import thrift.test.stress.Service;
class ServiceHandler : Service {
void echoVoid() { return; }
byte echoByte(byte arg) { return arg; }
int echoI32(int arg) { return arg; }
long echoI64(long arg) { return arg; }
byte[] echoList(byte[] arg) { return arg; }
HashSet!byte echoSet(HashSet!byte arg) { return arg; }
byte[byte] echoMap(byte[byte] arg) { return arg; }
string echoString(string arg) {
if (arg != "hello") {
stderr.writefln(`Wrong string received: %s instead of "hello"`, arg);
throw new Exception("Wrong string received.");
}
return arg;
}
}
void main(string[] args) {
ushort port = 9091;
auto serverType = ServerType.threaded;
TransportType transportType;
size_t numIOThreads = 1;
size_t taskPoolSize = totalCPUs;
getopt(args, "port", &port, "server-type", &serverType,
"transport-type", &transportType, "task-pool-size", &taskPoolSize,
"num-io-threads", &numIOThreads);
alias TypeTuple!(TBufferedTransport, TMemoryBuffer) AvailableTransports;
auto processor = new TServiceProcessor!(Service,
staticMap!(TBinaryProtocol, AvailableTransports))(new ServiceHandler());
auto serverSocket = new TServerSocket(port);
auto transportFactory = createTransportFactory(transportType);
auto protocolFactory = new TBinaryProtocolFactory!AvailableTransports;
auto server = createServer(serverType, taskPoolSize, numIOThreads,
processor, serverSocket, transportFactory, protocolFactory);
writefln("Starting %s %s StressTest server on port %s...", transportType,
serverType, port);
server.serve();
writeln("done.");
}

View file

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Various helpers used by more than a single test.
*/
module test_utils;
import std.parallelism : TaskPool;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.nonblocking;
import thrift.server.simple;
import thrift.server.taskpool;
import thrift.server.threaded;
import thrift.server.transport.socket;
import thrift.transport.base;
import thrift.transport.buffered;
import thrift.transport.framed;
import thrift.transport.http;
// This is a likely victim of @@BUG4744@@ when used with command argument
// parsing.
enum ServerType {
simple,
nonblocking,
pooledNonblocking,
taskpool,
threaded
}
TServer createServer(ServerType type, size_t taskPoolSize, size_t numIOThreads,
TProcessor processor, TServerSocket serverTransport,
TTransportFactory transportFactory, TProtocolFactory protocolFactory)
{
final switch (type) {
case ServerType.simple:
return new TSimpleServer(processor, serverTransport,
transportFactory, protocolFactory);
case ServerType.nonblocking:
auto nb = new TNonblockingServer(processor, serverTransport.port,
transportFactory, protocolFactory);
nb.numIOThreads = numIOThreads;
return nb;
case ServerType.pooledNonblocking:
auto nb = new TNonblockingServer(processor, serverTransport.port,
transportFactory, protocolFactory, new TaskPool(taskPoolSize));
nb.numIOThreads = numIOThreads;
return nb;
case ServerType.taskpool:
auto tps = new TTaskPoolServer(processor, serverTransport,
transportFactory, protocolFactory);
tps.taskPool = new TaskPool(taskPoolSize);
return tps;
case ServerType.threaded:
return new TThreadedServer(processor, serverTransport,
transportFactory, protocolFactory);
}
}
enum TransportType {
buffered,
framed,
http,
raw
}
TTransportFactory createTransportFactory(TransportType type) {
final switch (type) {
case TransportType.buffered:
return new TBufferedTransportFactory;
case TransportType.framed:
return new TFramedTransportFactory;
case TransportType.http:
return new TServerHttpTransportFactory;
case TransportType.raw:
return new TTransportFactory;
}
}

View file

@ -0,0 +1,386 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift_test_client;
import std.conv;
import std.datetime;
import std.exception : enforce;
import std.getopt;
import std.stdio;
import std.string;
import std.traits;
import thrift.base;
import thrift.codegen.client;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.protocol.compact;
import thrift.protocol.json;
import thrift.transport.base;
import thrift.transport.buffered;
import thrift.transport.framed;
import thrift.transport.http;
import thrift.transport.socket;
import thrift.transport.ssl;
import thrift.util.hashset;
import thrift_test_common;
import thrift.test.ThriftTest;
import thrift.test.ThriftTest_types;
enum TransportType {
buffered,
framed,
http,
raw
}
TProtocol createProtocol(T)(T trans, ProtocolType type) {
final switch (type) {
case ProtocolType.binary:
return tBinaryProtocol(trans);
case ProtocolType.compact:
return tCompactProtocol(trans);
case ProtocolType.json:
return tJsonProtocol(trans);
}
}
void main(string[] args) {
string host = "localhost";
ushort port = 9090;
uint numTests = 1;
bool ssl;
ProtocolType protocolType;
TransportType transportType;
bool trace;
getopt(args,
"numTests|n", &numTests,
"protocol", &protocolType,
"ssl", &ssl,
"transport", &transportType,
"trace", &trace,
"port", &port,
"host", (string _, string value) {
auto parts = split(value, ":");
if (parts.length > 1) {
// IPv6 addresses can contain colons, so take the last part for the
// port.
host = join(parts[0 .. $ - 1], ":");
port = to!ushort(parts[$ - 1]);
} else {
host = value;
}
}
);
port = to!ushort(port);
TSocket socket;
if (ssl) {
auto sslContext = new TSSLContext();
sslContext.ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
sslContext.authenticate = true;
sslContext.loadTrustedCertificates("../../../test/keys/CA.pem");
socket = new TSSLSocket(sslContext, host, port);
} else {
socket = new TSocket(host, port);
}
TProtocol protocol;
final switch (transportType) {
case TransportType.buffered:
protocol = createProtocol(new TBufferedTransport(socket), protocolType);
break;
case TransportType.framed:
protocol = createProtocol(new TFramedTransport(socket), protocolType);
break;
case TransportType.http:
protocol = createProtocol(
new TClientHttpTransport(socket, host, "/service"), protocolType);
break;
case TransportType.raw:
protocol = createProtocol(socket, protocolType);
break;
}
auto client = tClient!ThriftTest(protocol);
ulong time_min;
ulong time_max;
ulong time_tot;
StopWatch sw;
foreach(test; 0 .. numTests) {
sw.start();
protocol.transport.open();
if (trace) writefln("Test #%s, connect %s:%s", test + 1, host, port);
if (trace) write("testVoid()");
client.testVoid();
if (trace) writeln(" = void");
if (trace) write("testString(\"Test\")");
string s = client.testString("Test");
if (trace) writefln(" = \"%s\"", s);
enforce(s == "Test");
if (trace) write("testByte(1)");
byte u8 = client.testByte(1);
if (trace) writefln(" = %s", u8);
enforce(u8 == 1);
if (trace) write("testI32(-1)");
int i32 = client.testI32(-1);
if (trace) writefln(" = %s", i32);
enforce(i32 == -1);
if (trace) write("testI64(-34359738368)");
long i64 = client.testI64(-34359738368L);
if (trace) writefln(" = %s", i64);
enforce(i64 == -34359738368L);
if (trace) write("testDouble(-5.2098523)");
double dub = client.testDouble(-5.2098523);
if (trace) writefln(" = %s", dub);
enforce(dub == -5.2098523);
// TODO: add testBinary() call
Xtruct out1;
out1.string_thing = "Zero";
out1.byte_thing = 1;
out1.i32_thing = -3;
out1.i64_thing = -5;
if (trace) writef("testStruct(%s)", out1);
auto in1 = client.testStruct(out1);
if (trace) writefln(" = %s", in1);
enforce(in1 == out1);
if (trace) write("testNest({1, {\"Zero\", 1, -3, -5}), 5}");
Xtruct2 out2;
out2.byte_thing = 1;
out2.struct_thing = out1;
out2.i32_thing = 5;
auto in2 = client.testNest(out2);
in1 = in2.struct_thing;
if (trace) writefln(" = {%s, {\"%s\", %s, %s, %s}, %s}", in2.byte_thing,
in1.string_thing, in1.byte_thing, in1.i32_thing, in1.i64_thing,
in2.i32_thing);
enforce(in2 == out2);
int[int] mapout;
for (int i = 0; i < 5; ++i) {
mapout[i] = i - 10;
}
if (trace) writef("testMap({%s})", mapout);
auto mapin = client.testMap(mapout);
if (trace) writefln(" = {%s}", mapin);
enforce(mapin == mapout);
auto setout = new HashSet!int;
for (int i = -2; i < 3; ++i) {
setout ~= i;
}
if (trace) writef("testSet(%s)", setout);
auto setin = client.testSet(setout);
if (trace) writefln(" = %s", setin);
enforce(setin == setout);
int[] listout;
for (int i = -2; i < 3; ++i) {
listout ~= i;
}
if (trace) writef("testList(%s)", listout);
auto listin = client.testList(listout);
if (trace) writefln(" = %s", listin);
enforce(listin == listout);
{
if (trace) write("testEnum(ONE)");
auto ret = client.testEnum(Numberz.ONE);
if (trace) writefln(" = %s", ret);
enforce(ret == Numberz.ONE);
if (trace) write("testEnum(TWO)");
ret = client.testEnum(Numberz.TWO);
if (trace) writefln(" = %s", ret);
enforce(ret == Numberz.TWO);
if (trace) write("testEnum(THREE)");
ret = client.testEnum(Numberz.THREE);
if (trace) writefln(" = %s", ret);
enforce(ret == Numberz.THREE);
if (trace) write("testEnum(FIVE)");
ret = client.testEnum(Numberz.FIVE);
if (trace) writefln(" = %s", ret);
enforce(ret == Numberz.FIVE);
if (trace) write("testEnum(EIGHT)");
ret = client.testEnum(Numberz.EIGHT);
if (trace) writefln(" = %s", ret);
enforce(ret == Numberz.EIGHT);
}
if (trace) write("testTypedef(309858235082523)");
UserId uid = client.testTypedef(309858235082523L);
if (trace) writefln(" = %s", uid);
enforce(uid == 309858235082523L);
if (trace) write("testMapMap(1)");
auto mm = client.testMapMap(1);
if (trace) writefln(" = {%s}", mm);
// Simply doing == doesn't seem to work for nested AAs.
foreach (key, value; mm) {
enforce(testMapMapReturn[key] == value);
}
foreach (key, value; testMapMapReturn) {
enforce(mm[key] == value);
}
Insanity insane;
insane.userMap[Numberz.FIVE] = 5000;
Xtruct truck;
truck.string_thing = "Truck";
truck.byte_thing = 8;
truck.i32_thing = 8;
truck.i64_thing = 8;
insane.xtructs ~= truck;
if (trace) write("testInsanity()");
auto whoa = client.testInsanity(insane);
if (trace) writefln(" = %s", whoa);
// Commented for now, this is cumbersome to write without opEqual getting
// called on AA comparison.
// enforce(whoa == testInsanityReturn);
{
try {
if (trace) write("client.testException(\"Xception\") =>");
client.testException("Xception");
if (trace) writeln(" void\nFAILURE");
throw new Exception("testException failed.");
} catch (Xception e) {
if (trace) writefln(" {%s, \"%s\"}", e.errorCode, e.message);
}
try {
if (trace) write("client.testException(\"TException\") =>");
client.testException("Xception");
if (trace) writeln(" void\nFAILURE");
throw new Exception("testException failed.");
} catch (TException e) {
if (trace) writefln(" {%s}", e.msg);
}
try {
if (trace) write("client.testException(\"success\") =>");
client.testException("success");
if (trace) writeln(" void");
} catch (Exception e) {
if (trace) writeln(" exception\nFAILURE");
throw new Exception("testException failed.");
}
}
{
try {
if (trace) write("client.testMultiException(\"Xception\", \"test 1\") =>");
auto result = client.testMultiException("Xception", "test 1");
if (trace) writeln(" result\nFAILURE");
throw new Exception("testMultiException failed.");
} catch (Xception e) {
if (trace) writefln(" {%s, \"%s\"}", e.errorCode, e.message);
}
try {
if (trace) write("client.testMultiException(\"Xception2\", \"test 2\") =>");
auto result = client.testMultiException("Xception2", "test 2");
if (trace) writeln(" result\nFAILURE");
throw new Exception("testMultiException failed.");
} catch (Xception2 e) {
if (trace) writefln(" {%s, {\"%s\"}}",
e.errorCode, e.struct_thing.string_thing);
}
try {
if (trace) writef("client.testMultiException(\"success\", \"test 3\") =>");
auto result = client.testMultiException("success", "test 3");
if (trace) writefln(" {{\"%s\"}}", result.string_thing);
} catch (Exception e) {
if (trace) writeln(" exception\nFAILURE");
throw new Exception("testMultiException failed.");
}
}
// Do not run oneway test when doing multiple iterations, as it blocks the
// server for three seconds.
if (numTests == 1) {
if (trace) writef("client.testOneway(3) =>");
auto onewayWatch = StopWatch(AutoStart.yes);
client.testOneway(3);
onewayWatch.stop();
if (onewayWatch.peek().msecs > 200) {
if (trace) {
writefln(" FAILURE - took %s ms", onewayWatch.peek().usecs / 1000.0);
}
throw new Exception("testOneway failed.");
} else {
if (trace) {
writefln(" success - took %s ms", onewayWatch.peek().usecs / 1000.0);
}
}
// Redo a simple test after the oneway to make sure we aren't "off by
// one", which would be the case if the server treated oneway methods
// like normal ones.
if (trace) write("re-test testI32(-1)");
i32 = client.testI32(-1);
if (trace) writefln(" = %s", i32);
}
// Time metering.
sw.stop();
immutable tot = sw.peek().usecs;
if (trace) writefln("Total time: %s us\n", tot);
time_tot += tot;
if (time_min == 0 || tot < time_min) {
time_min = tot;
}
if (tot > time_max) {
time_max = tot;
}
protocol.transport.close();
sw.reset();
}
writeln("All tests done.");
if (numTests > 1) {
auto time_avg = time_tot / numTests;
writefln("Min time: %s us", time_min);
writefln("Max time: %s us", time_max);
writefln("Avg time: %s us", time_avg);
}
}

View file

@ -0,0 +1,92 @@
module thrift_test_common;
import std.stdio;
import thrift.test.ThriftTest_types;
enum ProtocolType {
binary,
compact,
json
}
void writeInsanityReturn(in Insanity[Numberz][UserId] insane) {
write("{");
foreach(key1, value1; insane) {
writef("%s => {", key1);
foreach(key2, value2; value1) {
writef("%s => {", key2);
write("{");
foreach(key3, value3; value2.userMap) {
writef("%s => %s, ", key3, value3);
}
write("}, ");
write("{");
foreach (x; value2.xtructs) {
writef("{\"%s\", %s, %s, %s}, ",
x.string_thing, x.byte_thing, x.i32_thing, x.i64_thing);
}
write("}");
write("}, ");
}
write("}, ");
}
write("}");
}
Insanity[Numberz][UserId] testInsanityReturn;
int[int][int] testMapMapReturn;
static this() {
testInsanityReturn = {
Insanity[Numberz][UserId] insane;
Xtruct hello;
hello.string_thing = "Hello2";
hello.byte_thing = 2;
hello.i32_thing = 2;
hello.i64_thing = 2;
Xtruct goodbye;
goodbye.string_thing = "Goodbye4";
goodbye.byte_thing = 4;
goodbye.i32_thing = 4;
goodbye.i64_thing = 4;
Insanity crazy;
crazy.userMap[Numberz.EIGHT] = 8;
crazy.xtructs ~= goodbye;
Insanity looney;
// The C++ TestServer also assigns these to crazy, but that is probably
// an oversight.
looney.userMap[Numberz.FIVE] = 5;
looney.xtructs ~= hello;
Insanity[Numberz] first_map;
first_map[Numberz.TWO] = crazy;
first_map[Numberz.THREE] = crazy;
insane[1] = first_map;
Insanity[Numberz] second_map;
second_map[Numberz.SIX] = looney;
insane[2] = second_map;
return insane;
}();
testMapMapReturn = {
int[int] pos;
int[int] neg;
for (int i = 1; i < 5; i++) {
pos[i] = i;
neg[-i] = -i;
}
int[int][int] result;
result[4] = pos;
result[-4] = neg;
return result;
}();
}

View file

@ -0,0 +1,93 @@
#!/bin/bash
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# Runs the D ThriftTest client and servers for all combinations of transport,
# protocol, SSL-mode and server type.
# Pass -k to keep going after failed tests.
CUR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
protocols="binary compact json"
# TODO: fix and enable http
# transports="buffered framed raw http"
transports="buffered framed raw"
servers="simple taskpool threaded"
framed_only_servers="nonblocking pooledNonblocking"
# Don't leave any server instances behind when interrupted (e.g. by Ctrl+C)
# or terminated.
trap "kill $(jobs -p) 2>/dev/null" INT TERM
for protocol in $protocols; do
for ssl in "" " --ssl"; do
for transport in $transports; do
for server in $servers $framed_only_servers; do
case $framed_only_servers in
*$server*) if [ $transport != "framed" ] || [ $ssl != "" ]; then continue; fi;;
esac
args="--transport=$transport --protocol=$protocol$ssl"
${CUR}/thrift_test_server $args --server-type=$server > /dev/null &
server_pid=$!
# Give the server some time to get up and check if it runs (yes, this
# is a huge kludge, should add a connect timeout to test client).
client_rc=-1
if [ "$server" = "taskpool" ]; then
sleep 0.5
else
sleep 0.02
fi
kill -0 $server_pid 2>/dev/null
if [ $? -eq 0 ]; then
${CUR}/thrift_test_client $args --numTests=10 > /dev/null
client_rc=$?
# Temporarily redirect stderr to null to avoid job control messages,
# restore it afterwards.
exec 3>&2
exec 2>/dev/null
kill $server_pid
exec 3>&2
fi
# Get the server exit code (wait should immediately return).
wait $server_pid
server_rc=$?
if [ $client_rc -ne 0 -o $server_rc -eq 1 ]; then
echo -e "\nTests failed for: $args --server-type=$server"
failed="true"
if [ "$1" != "-k" ]; then
exit 1
fi
else
echo -n "."
fi
done
done
done
done
echo
if [ -z "$failed" ]; then
echo "All tests passed."
fi

View file

@ -0,0 +1,286 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift_test_server;
import core.thread : dur, Thread;
import std.algorithm;
import std.exception : enforce;
import std.getopt;
import std.parallelism : totalCPUs;
import std.string;
import std.stdio;
import std.typetuple : TypeTuple, staticMap;
import thrift.base;
import thrift.codegen.processor;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.protocol.compact;
import thrift.protocol.json;
import thrift.server.base;
import thrift.server.transport.socket;
import thrift.server.transport.ssl;
import thrift.transport.base;
import thrift.transport.buffered;
import thrift.transport.framed;
import thrift.transport.http;
import thrift.transport.ssl;
import thrift.util.hashset;
import test_utils;
import thrift_test_common;
import thrift.test.ThriftTest_types;
import thrift.test.ThriftTest;
class TestHandler : ThriftTest {
this(bool trace) {
trace_ = trace;
}
override void testVoid() {
if (trace_) writeln("testVoid()");
}
override string testString(string thing) {
if (trace_) writefln("testString(\"%s\")", thing);
return thing;
}
override byte testByte(byte thing) {
if (trace_) writefln("testByte(%s)", thing);
return thing;
}
override int testI32(int thing) {
if (trace_) writefln("testI32(%s)", thing);
return thing;
}
override long testI64(long thing) {
if (trace_) writefln("testI64(%s)", thing);
return thing;
}
override double testDouble(double thing) {
if (trace_) writefln("testDouble(%s)", thing);
return thing;
}
override string testBinary(string thing) {
if (trace_) writefln("testBinary(\"%s\")", thing);
return thing;
}
override bool testBool(bool thing) {
if (trace_) writefln("testBool(\"%s\")", thing);
return thing;
}
override Xtruct testStruct(ref const(Xtruct) thing) {
if (trace_) writefln("testStruct({\"%s\", %s, %s, %s})",
thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing);
return thing;
}
override Xtruct2 testNest(ref const(Xtruct2) nest) {
auto thing = nest.struct_thing;
if (trace_) writefln("testNest({%s, {\"%s\", %s, %s, %s}, %s})",
nest.byte_thing, thing.string_thing, thing.byte_thing, thing.i32_thing,
thing.i64_thing, nest.i32_thing);
return nest;
}
override int[int] testMap(int[int] thing) {
if (trace_) writefln("testMap({%s})", thing);
return thing;
}
override HashSet!int testSet(HashSet!int thing) {
if (trace_) writefln("testSet({%s})",
join(map!`to!string(a)`(thing[]), ", "));
return thing;
}
override int[] testList(int[] thing) {
if (trace_) writefln("testList(%s)", thing);
return thing;
}
override Numberz testEnum(Numberz thing) {
if (trace_) writefln("testEnum(%s)", thing);
return thing;
}
override UserId testTypedef(UserId thing) {
if (trace_) writefln("testTypedef(%s)", thing);
return thing;
}
override string[string] testStringMap(string[string] thing) {
if (trace_) writefln("testStringMap(%s)", thing);
return thing;
}
override int[int][int] testMapMap(int hello) {
if (trace_) writefln("testMapMap(%s)", hello);
return testMapMapReturn;
}
override Insanity[Numberz][UserId] testInsanity(ref const(Insanity) argument) {
if (trace_) writeln("testInsanity()");
Insanity[Numberz][UserId] ret;
Insanity[Numberz] m1;
Insanity[Numberz] m2;
Insanity tmp;
tmp = cast(Insanity)argument;
m1[Numberz.TWO] = tmp;
m1[Numberz.THREE] = tmp;
m2[Numberz.SIX] = Insanity();
ret[1] = m1;
ret[2] = m2;
return ret;
}
override Xtruct testMulti(byte arg0, int arg1, long arg2, string[short] arg3,
Numberz arg4, UserId arg5)
{
if (trace_) writeln("testMulti()");
return Xtruct("Hello2", arg0, arg1, arg2);
}
override void testException(string arg) {
if (trace_) writefln("testException(%s)", arg);
if (arg == "Xception") {
auto e = new Xception();
e.errorCode = 1001;
e.message = arg;
throw e;
} else if (arg == "TException") {
throw new TException();
} else if (arg == "ApplicationException") {
throw new TException();
}
}
override Xtruct testMultiException(string arg0, string arg1) {
if (trace_) writefln("testMultiException(%s, %s)", arg0, arg1);
if (arg0 == "Xception") {
auto e = new Xception();
e.errorCode = 1001;
e.message = "This is an Xception";
throw e;
} else if (arg0 == "Xception2") {
auto e = new Xception2();
e.errorCode = 2002;
e.struct_thing.string_thing = "This is an Xception2";
throw e;
} else {
return Xtruct(arg1);
}
}
override void testOneway(int sleepFor) {
if (trace_) writefln("testOneway(%s): Sleeping...", sleepFor);
Thread.sleep(dur!"seconds"(sleepFor));
if (trace_) writefln("testOneway(%s): done sleeping!", sleepFor);
}
private:
bool trace_;
}
void main(string[] args) {
ushort port = 9090;
ServerType serverType;
ProtocolType protocolType;
size_t numIOThreads = 1;
TransportType transportType;
bool ssl;
bool trace;
size_t taskPoolSize = totalCPUs;
getopt(args, "port", &port, "protocol", &protocolType, "server-type",
&serverType, "ssl", &ssl, "num-io-threads", &numIOThreads,
"task-pool-size", &taskPoolSize, "trace", &trace,
"transport", &transportType);
if (serverType == ServerType.nonblocking ||
serverType == ServerType.pooledNonblocking
) {
enforce(transportType == TransportType.framed,
"Need to use framed transport with non-blocking server.");
enforce(!ssl, "The non-blocking server does not support SSL yet.");
// Don't wrap the contents into another layer of framing.
transportType = TransportType.raw;
}
version (ThriftTestTemplates) {
// Only exercise the specialized template code paths if explicitly enabled
// to reduce memory consumption on regular test suite runs there should
// not be much that can go wrong with that specifically anyway.
alias TypeTuple!(TBufferedTransport, TFramedTransport, TServerHttpTransport)
AvailableTransports;
alias TypeTuple!(
staticMap!(TBinaryProtocol, AvailableTransports),
staticMap!(TCompactProtocol, AvailableTransports)
) AvailableProtocols;
} else {
alias TypeTuple!() AvailableTransports;
alias TypeTuple!() AvailableProtocols;
}
TProtocolFactory protocolFactory;
final switch (protocolType) {
case ProtocolType.binary:
protocolFactory = new TBinaryProtocolFactory!AvailableTransports;
break;
case ProtocolType.compact:
protocolFactory = new TCompactProtocolFactory!AvailableTransports;
break;
case ProtocolType.json:
protocolFactory = new TJsonProtocolFactory!AvailableTransports;
break;
}
auto processor = new TServiceProcessor!(ThriftTest, AvailableProtocols)(
new TestHandler(trace));
TServerSocket serverSocket;
if (ssl) {
auto sslContext = new TSSLContext();
sslContext.serverSide = true;
sslContext.loadCertificate("../../../test/keys/server.crt");
sslContext.loadPrivateKey("../../../test/keys/server.key");
sslContext.ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
serverSocket = new TSSLServerSocket(port, sslContext);
} else {
serverSocket = new TServerSocket(port);
}
auto transportFactory = createTransportFactory(transportType);
auto server = createServer(serverType, numIOThreads, taskPoolSize,
processor, serverSocket, transportFactory, protocolFactory);
writefln("Starting %s/%s %s ThriftTest server %son port %s...", protocolType,
transportType, serverType, ssl ? "(using SSL) ": "", port);
server.serve();
writeln("done.");
}

View file

@ -0,0 +1,803 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Exercises various transports, combined with the buffered/framed wrappers.
*
* Originally ported from the C++ version, with Windows support code added.
*/
module transport_test;
import core.atomic;
import core.time : Duration;
import core.thread : Thread;
import std.conv : to;
import std.datetime;
import std.exception : enforce;
static import std.file;
import std.getopt;
import std.random : rndGen, uniform, unpredictableSeed;
import std.socket;
import std.stdio;
import std.string;
import std.typetuple;
import thrift.transport.base;
import thrift.transport.buffered;
import thrift.transport.framed;
import thrift.transport.file;
import thrift.transport.http;
import thrift.transport.memory;
import thrift.transport.socket;
import thrift.transport.zlib;
/*
* Size generation helpers used to be able to run the same testing code
* with both constant and random total/chunk sizes.
*/
interface SizeGenerator {
size_t nextSize();
string toString();
}
class ConstantSizeGenerator : SizeGenerator {
this(size_t value) {
value_ = value;
}
override size_t nextSize() {
return value_;
}
override string toString() const {
return to!string(value_);
}
private:
size_t value_;
}
class RandomSizeGenerator : SizeGenerator {
this(size_t min, size_t max) {
min_ = min;
max_ = max;
}
override size_t nextSize() {
return uniform!"[]"(min_, max_);
}
override string toString() const {
return format("rand(%s, %s)", min_, max_);
}
size_t min() const @property {
return min_;
}
size_t max() const @property {
return max_;
}
private:
size_t min_;
size_t max_;
}
/*
* Classes to set up coupled transports
*/
/**
* Helper class to represent a coupled pair of transports.
*
* Data written to the output transport can be read from the input transport.
*
* This is used as the base class for the various coupled transport
* implementations. It shouldn't be used directly.
*/
class CoupledTransports(Transport) if (isTTransport!Transport) {
Transport input;
Transport output;
}
template isCoupledTransports(T) {
static if (is(T _ : CoupledTransports!U, U)) {
enum isCoupledTransports = true;
} else {
enum isCoupledTransports = false;
}
}
/**
* Helper template class for creating coupled transports that wrap
* another transport.
*/
class CoupledWrapperTransports(WrapperTransport, InnerCoupledTransports) if (
isTTransport!WrapperTransport && isCoupledTransports!InnerCoupledTransports
) : CoupledTransports!WrapperTransport {
this() {
inner_ = new InnerCoupledTransports();
if (inner_.input) {
input = new WrapperTransport(inner_.input);
}
if (inner_.output) {
output = new WrapperTransport(inner_.output);
}
}
~this() {
destroy(inner_);
}
private:
InnerCoupledTransports inner_;
}
import thrift.internal.codegen : PApply;
alias PApply!(CoupledWrapperTransports, TBufferedTransport) CoupledBufferedTransports;
alias PApply!(CoupledWrapperTransports, TFramedTransport) CoupledFramedTransports;
alias PApply!(CoupledWrapperTransports, TZlibTransport) CoupledZlibTransports;
/**
* Coupled TMemoryBuffers.
*/
class CoupledMemoryBuffers : CoupledTransports!TMemoryBuffer {
this() {
buf = new TMemoryBuffer;
input = buf;
output = buf;
}
TMemoryBuffer buf;
}
/**
* Coupled TSockets.
*/
class CoupledSocketTransports : CoupledTransports!TSocket {
this() {
auto sockets = socketPair();
input = new TSocket(sockets[0]);
output = new TSocket(sockets[1]);
}
~this() {
input.close();
output.close();
}
}
/**
* Coupled TFileTransports
*/
class CoupledFileTransports : CoupledTransports!TTransport {
this() {
// We actually need the file name of the temp file here, so we can't just
// use the usual tempfile facilities.
do {
fileName_ = tmpDir ~ "/thrift.transport_test." ~ to!string(rndGen().front);
rndGen().popFront();
} while (std.file.exists(fileName_));
writefln("Using temp file: %s", fileName_);
auto writer = new TFileWriterTransport(fileName_);
writer.open();
output = writer;
// Wait until the file has been created.
writer.flush();
auto reader = new TFileReaderTransport(fileName_);
reader.open();
reader.readTimeout(dur!"msecs"(-1));
input = reader;
}
~this() {
input.close();
output.close();
std.file.remove(fileName_);
}
static string tmpDir;
private:
string fileName_;
}
/*
* Test functions
*/
/**
* Test interleaved write and read calls.
*
* Generates a buffer totalSize bytes long, then writes it to the transport,
* and verifies the written data can be read back correctly.
*
* Mode of operation:
* - call wChunkGenerator to figure out how large of a chunk to write
* - call wSizeGenerator to get the size for individual write() calls,
* and do this repeatedly until the entire chunk is written.
* - call rChunkGenerator to figure out how large of a chunk to read
* - call rSizeGenerator to get the size for individual read() calls,
* and do this repeatedly until the entire chunk is read.
* - repeat until the full buffer is written and read back,
* then compare the data read back against the original buffer
*
*
* - If any of the size generators return 0, this means to use the maximum
* possible size.
*
* - If maxOutstanding is non-zero, write chunk sizes will be chosen such that
* there are never more than maxOutstanding bytes waiting to be read back.
*/
void testReadWrite(CoupledTransports)(
size_t totalSize,
SizeGenerator wSizeGenerator,
SizeGenerator rSizeGenerator,
SizeGenerator wChunkGenerator,
SizeGenerator rChunkGenerator,
size_t maxOutstanding
) if (
isCoupledTransports!CoupledTransports
) {
scope transports = new CoupledTransports;
assert(transports.input);
assert(transports.output);
auto wbuf = new ubyte[totalSize];
auto rbuf = new ubyte[totalSize];
// Store some data in wbuf.
foreach (i, ref b; wbuf) {
b = i & 0xff;
}
size_t totalWritten;
size_t totalRead;
while (totalRead < totalSize) {
// Determine how large a chunk of data to write.
auto wChunkSize = wChunkGenerator.nextSize();
if (wChunkSize == 0 || wChunkSize > totalSize - totalWritten) {
wChunkSize = totalSize - totalWritten;
}
// Make sure (totalWritten - totalRead) + wChunkSize is less than
// maxOutstanding.
if (maxOutstanding > 0 &&
wChunkSize > maxOutstanding - (totalWritten - totalRead)) {
wChunkSize = maxOutstanding - (totalWritten - totalRead);
}
// Write the chunk.
size_t chunkWritten = 0;
while (chunkWritten < wChunkSize) {
auto writeSize = wSizeGenerator.nextSize();
if (writeSize == 0 || writeSize > wChunkSize - chunkWritten) {
writeSize = wChunkSize - chunkWritten;
}
transports.output.write(wbuf[totalWritten .. totalWritten + writeSize]);
chunkWritten += writeSize;
totalWritten += writeSize;
}
// Flush the data, so it will be available in the read transport
// Don't flush if wChunkSize is 0. (This should only happen if
// totalWritten == totalSize already, and we're only reading now.)
if (wChunkSize > 0) {
transports.output.flush();
}
// Determine how large a chunk of data to read back.
auto rChunkSize = rChunkGenerator.nextSize();
if (rChunkSize == 0 || rChunkSize > totalWritten - totalRead) {
rChunkSize = totalWritten - totalRead;
}
// Read the chunk.
size_t chunkRead;
while (chunkRead < rChunkSize) {
auto readSize = rSizeGenerator.nextSize();
if (readSize == 0 || readSize > rChunkSize - chunkRead) {
readSize = rChunkSize - chunkRead;
}
size_t bytesRead;
try {
bytesRead = transports.input.read(
rbuf[totalRead .. totalRead + readSize]);
} catch (TTransportException e) {
throw new Exception(format(`read(pos = %s, size = %s) threw ` ~
`exception "%s"; written so far: %s/%s bytes`, totalRead, readSize,
e.msg, totalWritten, totalSize));
}
enforce(bytesRead > 0, format(`read(pos = %s, size = %s) returned %s; ` ~
`written so far: %s/%s bytes`, totalRead, readSize, bytesRead,
totalWritten, totalSize));
chunkRead += bytesRead;
totalRead += bytesRead;
}
}
// make sure the data read back is identical to the data written
if (rbuf != wbuf) {
stderr.writefln("%s vs. %s", wbuf[$ - 4 .. $], rbuf[$ - 4 .. $]);
stderr.writefln("rbuf: %s vs. wbuf: %s", rbuf.length, wbuf.length);
}
enforce(rbuf == wbuf);
}
void testReadPartAvailable(CoupledTransports)() if (
isCoupledTransports!CoupledTransports
) {
scope transports = new CoupledTransports;
assert(transports.input);
assert(transports.output);
ubyte[10] writeBuf = 'a';
ubyte[10] readBuf;
// Attemping to read 10 bytes when only 9 are available should return 9
// immediately.
transports.output.write(writeBuf[0 .. 9]);
transports.output.flush();
auto t = Trigger(dur!"seconds"(3), transports.output, 1);
auto bytesRead = transports.input.read(readBuf);
enforce(t.fired == 0);
enforce(bytesRead == 9);
}
void testReadPartialMidframe(CoupledTransports)() if (
isCoupledTransports!CoupledTransports
) {
scope transports = new CoupledTransports;
assert(transports.input);
assert(transports.output);
ubyte[13] writeBuf = 'a';
ubyte[14] readBuf;
// Attempt to read 10 bytes, when only 9 are available, but after we have
// already read part of the data that is available. This exercises a
// different code path for several of the transports.
//
// For transports that add their own framing (e.g., TFramedTransport and
// TFileTransport), the two flush calls break up the data in to a 10 byte
// frame and a 3 byte frame. The first read then puts us partway through the
// first frame, and then we attempt to read past the end of that frame, and
// through the next frame, too.
//
// For buffered transports that perform read-ahead (e.g.,
// TBufferedTransport), the read-ahead will most likely see all 13 bytes
// written on the first read. The next read will then attempt to read past
// the end of the read-ahead buffer.
//
// Flush 10 bytes, then 3 bytes. This creates 2 separate frames for
// transports that track framing internally.
transports.output.write(writeBuf[0 .. 10]);
transports.output.flush();
transports.output.write(writeBuf[10 .. 13]);
transports.output.flush();
// Now read 4 bytes, so that we are partway through the written data.
auto bytesRead = transports.input.read(readBuf[0 .. 4]);
enforce(bytesRead == 4);
// Now attempt to read 10 bytes. Only 9 more are available.
//
// We should be able to get all 9 bytes, but it might take multiple read
// calls, since it is valid for read() to return fewer bytes than requested.
// (Most transports do immediately return 9 bytes, but the framing transports
// tend to only return to the end of the current frame, which is 6 bytes in
// this case.)
size_t totalRead = 0;
while (totalRead < 9) {
auto t = Trigger(dur!"seconds"(3), transports.output, 1);
bytesRead = transports.input.read(readBuf[4 + totalRead .. 14]);
enforce(t.fired == 0);
enforce(bytesRead > 0);
totalRead += bytesRead;
enforce(totalRead <= 9);
}
enforce(totalRead == 9);
}
void testBorrowPartAvailable(CoupledTransports)() if (
isCoupledTransports!CoupledTransports
) {
scope transports = new CoupledTransports;
assert(transports.input);
assert(transports.output);
ubyte[9] writeBuf = 'a';
ubyte[10] readBuf;
// Attemping to borrow 10 bytes when only 9 are available should return NULL
// immediately.
transports.output.write(writeBuf);
transports.output.flush();
auto t = Trigger(dur!"seconds"(3), transports.output, 1);
auto borrowLen = readBuf.length;
auto borrowedBuf = transports.input.borrow(readBuf.ptr, borrowLen);
enforce(t.fired == 0);
enforce(borrowedBuf is null);
}
void testReadNoneAvailable(CoupledTransports)() if (
isCoupledTransports!CoupledTransports
) {
scope transports = new CoupledTransports;
assert(transports.input);
assert(transports.output);
// Attempting to read when no data is available should either block until
// some data is available, or fail immediately. (e.g., TSocket blocks,
// TMemoryBuffer just fails.)
//
// If the transport blocks, it should succeed once some data is available,
// even if less than the amount requested becomes available.
ubyte[10] readBuf;
auto t = Trigger(dur!"seconds"(1), transports.output, 2);
t.add(dur!"seconds"(1), transports.output, 8);
auto bytesRead = transports.input.read(readBuf);
if (bytesRead == 0) {
enforce(t.fired == 0);
} else {
enforce(t.fired == 1);
enforce(bytesRead == 2);
}
}
void testBorrowNoneAvailable(CoupledTransports)() if (
isCoupledTransports!CoupledTransports
) {
scope transports = new CoupledTransports;
assert(transports.input);
assert(transports.output);
ubyte[16] writeBuf = 'a';
// Attempting to borrow when no data is available should fail immediately
auto t = Trigger(dur!"seconds"(1), transports.output, 10);
auto borrowLen = 10;
auto borrowedBuf = transports.input.borrow(null, borrowLen);
enforce(borrowedBuf is null);
enforce(t.fired == 0);
}
void doRwTest(CoupledTransports)(
size_t totalSize,
SizeGenerator wSizeGen,
SizeGenerator rSizeGen,
SizeGenerator wChunkSizeGen = new ConstantSizeGenerator(0),
SizeGenerator rChunkSizeGen = new ConstantSizeGenerator(0),
size_t maxOutstanding = 0
) if (
isCoupledTransports!CoupledTransports
) {
totalSize = cast(size_t)(totalSize * g_sizeMultiplier);
scope(failure) {
writefln("Test failed for %s: testReadWrite(%s, %s, %s, %s, %s, %s)",
CoupledTransports.stringof, totalSize, wSizeGen, rSizeGen,
wChunkSizeGen, rChunkSizeGen, maxOutstanding);
}
testReadWrite!CoupledTransports(totalSize, wSizeGen, rSizeGen,
wChunkSizeGen, rChunkSizeGen, maxOutstanding);
}
void doBlockingTest(CoupledTransports)() if (
isCoupledTransports!CoupledTransports
) {
void writeFailure(string name) {
writefln("Test failed for %s: %s()", CoupledTransports.stringof, name);
}
{
scope(failure) writeFailure("testReadPartAvailable");
testReadPartAvailable!CoupledTransports();
}
{
scope(failure) writeFailure("testReadPartialMidframe");
testReadPartialMidframe!CoupledTransports();
}
{
scope(failure) writeFailure("testReadNoneAvaliable");
testReadNoneAvailable!CoupledTransports();
}
{
scope(failure) writeFailure("testBorrowPartAvailable");
testBorrowPartAvailable!CoupledTransports();
}
{
scope(failure) writeFailure("testBorrowNoneAvailable");
testBorrowNoneAvailable!CoupledTransports();
}
}
SizeGenerator getGenerator(T)(T t) {
static if (is(T : SizeGenerator)) {
return t;
} else {
return new ConstantSizeGenerator(t);
}
}
template WrappedTransports(T) if (isCoupledTransports!T) {
alias TypeTuple!(
T,
CoupledBufferedTransports!T,
CoupledFramedTransports!T,
CoupledZlibTransports!T
) WrappedTransports;
}
void testRw(C, R, S)(
size_t totalSize,
R wSize,
S rSize
) if (
isCoupledTransports!C && is(typeof(getGenerator(wSize))) &&
is(typeof(getGenerator(rSize)))
) {
testRw!C(totalSize, wSize, rSize, 0, 0, 0);
}
void testRw(C, R, S, T, U)(
size_t totalSize,
R wSize,
S rSize,
T wChunkSize,
U rChunkSize,
size_t maxOutstanding = 0
) if (
isCoupledTransports!C && is(typeof(getGenerator(wSize))) &&
is(typeof(getGenerator(rSize))) && is(typeof(getGenerator(wChunkSize))) &&
is(typeof(getGenerator(rChunkSize)))
) {
foreach (T; WrappedTransports!C) {
doRwTest!T(
totalSize,
getGenerator(wSize),
getGenerator(rSize),
getGenerator(wChunkSize),
getGenerator(rChunkSize),
maxOutstanding
);
}
}
void testBlocking(C)() if (isCoupledTransports!C) {
foreach (T; WrappedTransports!C) {
doBlockingTest!T();
}
}
// A quick hack, for the sake of brevity…
float g_sizeMultiplier = 1;
version (Posix) {
immutable defaultTempDir = "/tmp";
} else version (Windows) {
import core.sys.windows.windows;
extern(Windows) DWORD GetTempPathA(DWORD nBufferLength, LPTSTR lpBuffer);
string defaultTempDir() @property {
char[MAX_PATH + 1] dir;
enforce(GetTempPathA(dir.length, dir.ptr));
return to!string(dir.ptr)[0 .. $ - 1];
}
} else static assert(false);
void main(string[] args) {
int seed = unpredictableSeed();
string tmpDir = defaultTempDir;
getopt(args, "seed", &seed, "size-multiplier", &g_sizeMultiplier,
"tmp-dir", &tmpDir);
enforce(g_sizeMultiplier >= 0, "Size multiplier must not be negative.");
writefln("Using seed: %s", seed);
rndGen().seed(seed);
CoupledFileTransports.tmpDir = tmpDir;
auto rand4k = new RandomSizeGenerator(1, 4096);
/*
* We do the basically the same set of tests for each transport type,
* although we tweak the parameters in some places.
*/
// TMemoryBuffer tests
testRw!CoupledMemoryBuffers(1024 * 1024, 0, 0);
testRw!CoupledMemoryBuffers(1024 * 256, rand4k, rand4k);
testRw!CoupledMemoryBuffers(1024 * 256, 167, 163);
testRw!CoupledMemoryBuffers(1024 * 16, 1, 1);
testRw!CoupledMemoryBuffers(1024 * 256, 0, 0, rand4k, rand4k);
testRw!CoupledMemoryBuffers(1024 * 256, rand4k, rand4k, rand4k, rand4k);
testRw!CoupledMemoryBuffers(1024 * 256, 167, 163, rand4k, rand4k);
testRw!CoupledMemoryBuffers(1024 * 16, 1, 1, rand4k, rand4k);
testBlocking!CoupledMemoryBuffers();
// TSocket tests
enum socketMaxOutstanding = 4096;
testRw!CoupledSocketTransports(1024 * 1024, 0, 0,
0, 0, socketMaxOutstanding);
testRw!CoupledSocketTransports(1024 * 256, rand4k, rand4k,
0, 0, socketMaxOutstanding);
testRw!CoupledSocketTransports(1024 * 256, 167, 163,
0, 0, socketMaxOutstanding);
// Doh. Apparently writing to a socket has some additional overhead for
// each send() call. If we have more than ~400 outstanding 1-byte write
// requests, additional send() calls start blocking.
testRw!CoupledSocketTransports(1024 * 16, 1, 1,
0, 0, 250);
testRw!CoupledSocketTransports(1024 * 256, 0, 0,
rand4k, rand4k, socketMaxOutstanding);
testRw!CoupledSocketTransports(1024 * 256, rand4k, rand4k,
rand4k, rand4k, socketMaxOutstanding);
testRw!CoupledSocketTransports(1024 * 256, 167, 163,
rand4k, rand4k, socketMaxOutstanding);
testRw!CoupledSocketTransports(1024 * 16, 1, 1,
rand4k, rand4k, 250);
testBlocking!CoupledSocketTransports();
// File transport tests.
// Cannot write more than the frame size at once.
enum maxWriteAtOnce = 1024 * 1024 * 16 - 4;
testRw!CoupledFileTransports(1024 * 1024, maxWriteAtOnce, 0);
testRw!CoupledFileTransports(1024 * 256, rand4k, rand4k);
testRw!CoupledFileTransports(1024 * 256, 167, 163);
testRw!CoupledFileTransports(1024 * 16, 1, 1);
testRw!CoupledFileTransports(1024 * 256, 0, 0, rand4k, rand4k);
testRw!CoupledFileTransports(1024 * 256, rand4k, rand4k, rand4k, rand4k);
testRw!CoupledFileTransports(1024 * 256, 167, 163, rand4k, rand4k);
testRw!CoupledFileTransports(1024 * 16, 1, 1, rand4k, rand4k);
testBlocking!CoupledFileTransports();
}
/*
* Timer handling code for use in tests that check the transport blocking
* semantics.
*
* The implementation has been hacked together in a hurry and wastes a lot of
* threads, but speed should not be the concern here.
*/
struct Trigger {
this(Duration timeout, TTransport transport, size_t writeLength) {
mutex_ = new Mutex;
cancelCondition_ = new Condition(mutex_);
info_ = new Info(timeout, transport, writeLength);
startThread();
}
~this() {
synchronized (mutex_) {
info_ = null;
cancelCondition_.notifyAll();
}
if (thread_) thread_.join();
}
@disable this(this) { assert(0); }
void add(Duration timeout, TTransport transport, size_t writeLength) {
synchronized (mutex_) {
auto info = new Info(timeout, transport, writeLength);
if (info_) {
auto prev = info_;
while (prev.next) prev = prev.next;
prev.next = info;
} else {
info_ = info;
startThread();
}
}
}
@property short fired() {
return atomicLoad(fired_);
}
private:
void timerThread() {
// KLUDGE: Make sure the std.concurrency mbox is initialized on the timer
// thread to be able to unblock the file transport.
import std.concurrency;
thisTid;
synchronized (mutex_) {
while (info_) {
auto cancelled = cancelCondition_.wait(info_.timeout);
if (cancelled) {
info_ = null;
break;
}
atomicOp!"+="(fired_, 1);
// Write some data to the transport to unblock it.
auto buf = new ubyte[info_.writeLength];
buf[] = 'b';
info_.transport.write(buf);
info_.transport.flush();
info_ = info_.next;
}
}
thread_ = null;
}
void startThread() {
thread_ = new Thread(&timerThread);
thread_.start();
}
struct Info {
this(Duration timeout, TTransport transport, size_t writeLength) {
this.timeout = timeout;
this.transport = transport;
this.writeLength = writeLength;
}
Duration timeout;
TTransport transport;
size_t writeLength;
Info* next;
}
Info* info_;
Thread thread_;
shared short fired_;
import core.sync.mutex;
Mutex mutex_;
import core.sync.condition;
Condition cancelCondition_;
}