Checking in vendor folder for ease of using go get.

This commit is contained in:
Renan DelValle 2018-10-23 23:32:59 -07:00
parent 7a1251853b
commit cdb4b5a1d0
No known key found for this signature in database
GPG key ID: C240AD6D6F443EC9
3554 changed files with 1270116 additions and 0 deletions

View file

@ -0,0 +1,228 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Defines the interface used for client-side handling of asynchronous
* I/O operations, based on coroutines.
*
* The main piece of the »client side« (e.g. for TAsyncClient users) of the
* API is TFuture, which represents an asynchronously executed operation,
* which can have a return value, throw exceptions, and which can be waited
* upon.
*
* On the »implementation side«, the idea is that by using a TAsyncTransport
* instead of a normal TTransport and executing the work through a
* TAsyncManager, the same code as for synchronous I/O can be used for
* asynchronous operation as well, for example:
*
* ---
* auto socket = new TAsyncSocket(someTAsyncSocketManager(), host, port);
* // …
* socket.asyncManager.execute(socket, {
* SomeThriftStruct s;
*
* // Waiting for socket I/O will not block an entire thread but cause
* // the async manager to execute another task in the meantime, because
* // we are using TAsyncSocket instead of TSocket.
* s.read(socket);
*
* // Do something with s, e.g. set a TPromise result to it.
* writeln(s);
* });
* ---
*/
module thrift.async.base;
import core.time : Duration, dur;
import std.socket/+ : Socket+/; // DMD @@BUG314@@
import thrift.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* Manages one or more asynchronous transport resources (e.g. sockets in the
* case of TAsyncSocketManager) and allows work items to be submitted for them.
*
* Implementations will typically run one or more background threads for
* executing the work, which is one of the reasons for a TAsyncManager to be
* used. Each work item is run in its own fiber and is expected to yield() away
* while waiting for time-consuming operations.
*
* The second important purpose of TAsyncManager is to serialize access to
* the transport resources without taking care of that, e.g. issuing multiple
* RPC calls over the same connection in rapid succession would likely lead to
* more than one request being written at the same time, causing only garbage
* to arrive at the remote end.
*
* All methods are thread-safe.
*/
interface TAsyncManager {
/**
* Submits a work item to be executed asynchronously.
*
* Access to asnyc transports is serialized if two work items associated
* with the same transport are submitted, the second delegate will not be
* invoked until the first has returned, even it the latter context-switches
* away (because it is waiting for I/O) and the async manager is idle
* otherwise.
*
* Optionally, a TCancellation instance can be specified. If present,
* triggering it will be considered a request to cancel the work item, if it
* is still waiting for the associated transport to become available.
* Delegates which are already being processed (i.e. waiting for I/O) are not
* affected because this would bring the connection into an undefined state
* (as probably half-written request or a half-read response would be left
* behind).
*
* Params:
* transport = The TAsyncTransport the work delegate will operate on. Must
* be associated with this TAsyncManager instance.
* work = The operations to execute on the given transport. Must never
* throw, errors should be handled in another way. nothrow semantics are
* difficult to enforce in combination with fibres though, so currently
* exceptions are just swallowed by TAsyncManager implementations.
* cancellation = If set, can be used to request cancellatinon of this work
* item if it is still waiting to be executed.
*
* Note: The work item will likely be executed in a different thread, so make
* sure the code it relies on is thread-safe. An exception are the async
* transports themselves, to which access is serialized as noted above.
*/
void execute(TAsyncTransport transport, void delegate() work,
TCancellation cancellation = null
) in {
assert(transport.asyncManager is this,
"The given transport must be associated with this TAsyncManager.");
}
/**
* Submits a delegate to be executed after a certain amount of time has
* passed.
*
* The actual amount of time elapsed can be higher if the async manager
* instance is busy and thus should not be relied on. The
*
* Params:
* duration = The amount of time to wait before starting to execute the
* work delegate.
* work = The code to execute after the specified amount of time has passed.
*
* Example:
* ---
* // A very basic example usually, the actuall work item would enqueue
* // some async transport operation.
* auto asyncMangager = someAsyncManager();
*
* TFuture!int calculate() {
* // Create a promise and asynchronously set its value after three
* // seconds have passed.
* auto promise = new TPromise!int;
* asyncManager.delay(dur!"seconds"(3), {
* promise.succeed(42);
* });
*
* // Immediately return it to the caller.
* return promise;
* }
*
* // This will wait until the result is available and then print it.
* writeln(calculate().waitGet());
* ---
*/
void delay(Duration duration, void delegate() work);
/**
* Shuts down all background threads or other facilities that might have
* been started in order to execute work items. This function is typically
* called during program shutdown.
*
* If there are still tasks to be executed when the timeout expires, any
* currently executed work items will never receive any notifications
* for async transports managed by this instance, queued work items will
* be silently dropped, and implementations are allowed to leak resources.
*
* Params:
* waitFinishTimeout = If positive, waits for all work items to be
* finished for the specified amount of time, if negative, waits for
* completion without ever timing out, if zero, immediately shuts down
* the background facilities.
*/
bool stop(Duration waitFinishTimeout = dur!"hnsecs"(-1));
}
/**
* A TTransport which uses a TAsyncManager to schedule non-blocking operations.
*
* The actual type of device is not specified; typically, implementations will
* depend on an interface derived from TAsyncManager to be notified of changes
* in the transport state.
*
* The peeking, reading, writing and flushing methods must always be called
* from within the associated async manager.
*/
interface TAsyncTransport : TTransport {
/**
* The TAsyncManager associated with this transport.
*/
TAsyncManager asyncManager() @property;
}
/**
* A TAsyncManager providing notificiations for socket events.
*/
interface TAsyncSocketManager : TAsyncManager {
/**
* Adds a listener that is triggered once when an event of the specified type
* occurs, and removed afterwards.
*
* Params:
* socket = The socket to listen for events at.
* eventType = The type of the event to listen for.
* timeout = The period of time after which the listener will be called
* with TAsyncEventReason.TIMED_OUT if no event happened.
* listener = The delegate to call when an event happened.
*/
void addOneshotListener(Socket socket, TAsyncEventType eventType,
Duration timeout, TSocketEventListener listener);
/// Ditto
void addOneshotListener(Socket socket, TAsyncEventType eventType,
TSocketEventListener listener);
}
/**
* Types of events that can happen for an asynchronous transport.
*/
enum TAsyncEventType {
READ, /// New data became available to read.
WRITE /// The transport became ready to be written to.
}
/**
* The type of the delegates used to register socket event handlers.
*/
alias void delegate(TAsyncEventReason callReason) TSocketEventListener;
/**
* The reason a listener was called.
*/
enum TAsyncEventReason : byte {
NORMAL, /// The event listened for was triggered normally.
TIMED_OUT /// A timeout for the event was set, and it expired.
}

View file

@ -0,0 +1,461 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.async.libevent;
import core.atomic;
import core.time : Duration, dur;
import core.exception : onOutOfMemoryError;
import core.memory : GC;
import core.thread : Fiber, Thread;
import core.sync.condition;
import core.sync.mutex;
import core.stdc.stdlib : free, malloc;
import deimos.event2.event;
import std.array : empty, front, popFront;
import std.conv : text, to;
import std.exception : enforce;
import std.socket : Socket, socketPair;
import thrift.base;
import thrift.async.base;
import thrift.internal.socket;
import thrift.internal.traits;
import thrift.util.cancellation;
// To avoid DMD @@BUG6395@@.
import thrift.internal.algorithm;
/**
* A TAsyncManager implementation based on libevent.
*
* The libevent loop for handling non-blocking sockets is run in a background
* thread, which is lazily spawned. The thread is not daemonized to avoid
* crashes on program shutdown, it is only stopped when the manager instance
* is destroyed. So, to ensure a clean program teardown, either make sure this
* instance gets destroyed (e.g. by using scope), or manually call stop() at
* the end.
*/
class TLibeventAsyncManager : TAsyncSocketManager {
this() {
eventBase_ = event_base_new();
// Set up the socket pair for transferring control messages to the event
// loop.
auto pair = socketPair();
controlSendSocket_ = pair[0];
controlReceiveSocket_ = pair[1];
controlReceiveSocket_.blocking = false;
// Register an event for receiving control messages.
controlReceiveEvent_ = event_new(eventBase_, controlReceiveSocket_.handle,
EV_READ | EV_PERSIST | EV_ET, assumeNothrow(&controlMsgReceiveCallback),
cast(void*)this);
event_add(controlReceiveEvent_, null);
queuedCountMutex_ = new Mutex;
zeroQueuedCondition_ = new Condition(queuedCountMutex_);
}
~this() {
// stop() should be safe to call, because either we don't have a worker
// thread running and it is a no-op anyway, or it is guaranteed to be
// still running (blocked in event_base_loop), and thus guaranteed not to
// be garbage collected yet.
stop(dur!"hnsecs"(0));
event_free(controlReceiveEvent_);
event_base_free(eventBase_);
eventBase_ = null;
}
override void execute(TAsyncTransport transport, Work work,
TCancellation cancellation = null
) {
if (cancellation && cancellation.triggered) return;
// Keep track that there is a new work item to be processed.
incrementQueuedCount();
ensureWorkerThreadRunning();
// We should be able to send the control message as a whole we currently
// assume to be able to receive it at once as well. If this proves to be
// unstable (e.g. send could possibly return early if the receiving buffer
// is full and the blocking call gets interrupted by a signal), it could
// be changed to a more sophisticated scheme.
// Make sure the delegate context doesn't get GCd while the work item is
// on the wire.
GC.addRoot(work.ptr);
// Send work message.
sendControlMsg(ControlMsg(MsgType.WORK, work, transport));
if (cancellation) {
cancellation.triggering.addCallback({
sendControlMsg(ControlMsg(MsgType.CANCEL, work, transport));
});
}
}
override void delay(Duration duration, void delegate() work) {
incrementQueuedCount();
ensureWorkerThreadRunning();
const tv = toTimeval(duration);
// DMD @@BUG@@: Cannot deduce T to void delegate() here.
registerOneshotEvent!(void delegate())(
-1, 0, assumeNothrow(&delayCallback), &tv,
{
work();
decrementQueuedCount();
}
);
}
override bool stop(Duration waitFinishTimeout = dur!"hnsecs"(-1)) {
bool cleanExit = true;
synchronized (this) {
if (workerThread_) {
synchronized (queuedCountMutex_) {
if (waitFinishTimeout > dur!"hnsecs"(0)) {
if (queuedCount_ > 0) {
zeroQueuedCondition_.wait(waitFinishTimeout);
}
} else if (waitFinishTimeout < dur!"hnsecs"(0)) {
while (queuedCount_ > 0) zeroQueuedCondition_.wait();
} else {
// waitFinishTimeout is zero, immediately exit in all cases.
}
cleanExit = (queuedCount_ == 0);
}
event_base_loopbreak(eventBase_);
sendControlMsg(ControlMsg(MsgType.SHUTDOWN));
workerThread_.join();
workQueues_ = null;
// We have nuked all currently enqueued items, so set the count to
// zero. This is safe to do without locking, since the worker thread
// is down.
queuedCount_ = 0;
atomicStore(*(cast(shared)&workerThread_), cast(shared(Thread))null);
}
}
return cleanExit;
}
override void addOneshotListener(Socket socket, TAsyncEventType eventType,
TSocketEventListener listener
) {
addOneshotListenerImpl(socket, eventType, null, listener);
}
override void addOneshotListener(Socket socket, TAsyncEventType eventType,
Duration timeout, TSocketEventListener listener
) {
if (timeout <= dur!"hnsecs"(0)) {
addOneshotListenerImpl(socket, eventType, null, listener);
} else {
// This is not really documented well, but libevent does not require to
// keep the timeval around after the event was added.
auto tv = toTimeval(timeout);
addOneshotListenerImpl(socket, eventType, &tv, listener);
}
}
private:
alias void delegate() Work;
void addOneshotListenerImpl(Socket socket, TAsyncEventType eventType,
const(timeval)* timeout, TSocketEventListener listener
) {
registerOneshotEvent(socket.handle, libeventEventType(eventType),
assumeNothrow(&socketCallback), timeout, listener);
}
void registerOneshotEvent(T)(evutil_socket_t fd, short type,
event_callback_fn callback, const(timeval)* timeout, T payload
) {
// Create a copy of the payload on the C heap.
auto payloadMem = malloc(payload.sizeof);
if (!payloadMem) onOutOfMemoryError();
(cast(T*)payloadMem)[0 .. 1] = payload;
GC.addRange(payloadMem, payload.sizeof);
auto result = event_base_once(eventBase_, fd, type, callback,
payloadMem, timeout);
// Assuming that we didn't get our arguments wrong above, the only other
// situation in which event_base_once can fail is when it can't allocate
// memory.
if (result != 0) onOutOfMemoryError();
}
enum MsgType : ubyte {
SHUTDOWN,
WORK,
CANCEL
}
struct ControlMsg {
MsgType type;
Work work;
TAsyncTransport transport;
}
/**
* Starts the worker thread if it is not already running.
*/
void ensureWorkerThreadRunning() {
// Technically, only half barriers would be required here, but adding the
// argument seems to trigger a DMD template argument deduction @@BUG@@.
if (!atomicLoad(*(cast(shared)&workerThread_))) {
synchronized (this) {
if (!workerThread_) {
auto thread = new Thread({ event_base_loop(eventBase_, 0); });
thread.start();
atomicStore(*(cast(shared)&workerThread_), cast(shared)thread);
}
}
}
}
/**
* Sends a control message to the worker thread.
*/
void sendControlMsg(const(ControlMsg) msg) {
auto result = controlSendSocket_.send((&msg)[0 .. 1]);
enum size = msg.sizeof;
enforce(result == size, new TException(text(
"Sending control message of type ", msg.type, " failed (", result,
" bytes instead of ", size, " transmitted).")));
}
/**
* Receives messages from the control message socket and acts on them. Called
* from the worker thread.
*/
void receiveControlMsg() {
// Read as many new work items off the socket as possible (at least one
// should be available, as we got notified by libevent).
ControlMsg msg;
ptrdiff_t bytesRead;
while (true) {
bytesRead = controlReceiveSocket_.receive(cast(ubyte[])((&msg)[0 .. 1]));
if (bytesRead < 0) {
auto errno = getSocketErrno();
if (errno != WOULD_BLOCK_ERRNO) {
logError("Reading control message, some work item will possibly " ~
"never be executed: %s", socketErrnoString(errno));
}
}
if (bytesRead != msg.sizeof) break;
// Everything went fine, we received a new control message.
final switch (msg.type) {
case MsgType.SHUTDOWN:
// The message was just intended to wake us up for shutdown.
break;
case MsgType.CANCEL:
// When processing a cancellation, we must not touch the first item,
// since it is already being processed.
auto queue = workQueues_[msg.transport];
if (queue.length > 0) {
workQueues_[msg.transport] = [queue[0]] ~
removeEqual(queue[1 .. $], msg.work);
}
break;
case MsgType.WORK:
// Now that the work item is back in the D world, we don't need the
// extra GC root for the context pointer anymore (see execute()).
GC.removeRoot(msg.work.ptr);
// Add the work item to the queue and execute it.
auto queue = msg.transport in workQueues_;
if (queue is null || (*queue).empty) {
// If the queue is empty, add the new work item to the queue as well,
// but immediately start executing it.
workQueues_[msg.transport] = [msg.work];
executeWork(msg.transport, msg.work);
} else {
(*queue) ~= msg.work;
}
break;
}
}
// If the last read was successful, but didn't read enough bytes, we got
// a problem.
if (bytesRead > 0) {
logError("Unexpected partial control message read (%s byte(s) " ~
"instead of %s), some work item will possibly never be executed.",
bytesRead, msg.sizeof);
}
}
/**
* Executes the given work item and all others enqueued for the same
* transport in a new fiber. Called from the worker thread.
*/
void executeWork(TAsyncTransport transport, Work work) {
(new Fiber({
auto item = work;
while (true) {
try {
// Execute the actual work. It will possibly add listeners to the
// event loop and yield away if it has to wait for blocking
// operations. It is quite possible that another fiber will modify
// the work queue for the current transport.
item();
} catch (Exception e) {
// This should never happen, just to be sure the worker thread
// doesn't stop working in mysterious ways because of an unhandled
// exception.
logError("Exception thrown by work item: %s", e);
}
// Remove the item from the work queue.
// Note: Due to the value semantics of array slices, we have to
// re-lookup this on every iteration. This could be solved, but I'd
// rather replace this directly with a queue type once one becomes
// available in Phobos.
auto queue = workQueues_[transport];
assert(queue.front == item);
queue.popFront();
workQueues_[transport] = queue;
// Now that the work item is done, no longer count it as queued.
decrementQueuedCount();
if (queue.empty) break;
// If the queue is not empty, execute the next waiting item.
item = queue.front;
}
})).call();
}
/**
* Increments the amount of queued items.
*/
void incrementQueuedCount() {
synchronized (queuedCountMutex_) {
++queuedCount_;
}
}
/**
* Decrements the amount of queued items.
*/
void decrementQueuedCount() {
synchronized (queuedCountMutex_) {
assert(queuedCount_ > 0);
--queuedCount_;
if (queuedCount_ == 0) {
zeroQueuedCondition_.notifyAll();
}
}
}
static extern(C) void controlMsgReceiveCallback(evutil_socket_t, short,
void *managerThis
) {
(cast(TLibeventAsyncManager)managerThis).receiveControlMsg();
}
static extern(C) void socketCallback(evutil_socket_t, short flags,
void *arg
) {
auto reason = (flags & EV_TIMEOUT) ? TAsyncEventReason.TIMED_OUT :
TAsyncEventReason.NORMAL;
(*(cast(TSocketEventListener*)arg))(reason);
GC.removeRange(arg);
destroy(arg);
free(arg);
}
static extern(C) void delayCallback(evutil_socket_t, short flags,
void *arg
) {
assert(flags & EV_TIMEOUT);
(*(cast(void delegate()*)arg))();
GC.removeRange(arg);
destroy(arg);
free(arg);
}
Thread workerThread_;
event_base* eventBase_;
/// The socket used for receiving new work items in the event loop. Paired
/// with controlSendSocket_. Invalid (i.e. TAsyncWorkItem.init) items are
/// ignored and can be used to wake up the worker thread.
Socket controlReceiveSocket_;
event* controlReceiveEvent_;
/// The socket used to send new work items to the event loop. It is
/// expected that work items can always be read at once from it, i.e. that
/// there will never be short reads.
Socket controlSendSocket_;
/// Queued up work delegates for async transports. This also includes
/// currently active ones, they are removed from the queue on completion,
/// which is relied on by the control message receive fiber (the main one)
/// to decide whether to immediately start executing items or not.
// TODO: This should really be of some queue type, not an array slice, but
// std.container doesn't have anything.
Work[][TAsyncTransport] workQueues_;
/// The total number of work items not yet finished (queued and currently
/// executed) and delays not yet executed.
uint queuedCount_;
/// Protects queuedCount_.
Mutex queuedCountMutex_;
/// Triggered when queuedCount_ reaches zero, protected by queuedCountMutex_.
Condition zeroQueuedCondition_;
}
private {
timeval toTimeval(const(Duration) dur) {
timeval tv;
dur.split!("seconds", "usecs")(tv.tv_sec, tv.tv_usec);
return tv;
}
/**
* Returns the libevent flags combination to represent a given TAsyncEventType.
*/
short libeventEventType(TAsyncEventType type) {
final switch (type) {
case TAsyncEventType.READ:
return EV_READ | EV_ET;
case TAsyncEventType.WRITE:
return EV_WRITE | EV_ET;
}
}
}

View file

@ -0,0 +1,357 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.async.socket;
import core.thread : Fiber;
import core.time : dur, Duration;
import std.array : empty;
import std.conv : to;
import std.exception : enforce;
import std.socket;
import thrift.base;
import thrift.async.base;
import thrift.transport.base;
import thrift.transport.socket : TSocketBase;
import thrift.internal.endian;
import thrift.internal.socket;
version (Windows) {
import std.c.windows.winsock : connect;
} else version (Posix) {
import core.sys.posix.sys.socket : connect;
} else static assert(0, "Don't know connect on this platform.");
/**
* Non-blocking socket implementation of the TTransport interface.
*
* Whenever a socket operation would block, TAsyncSocket registers a callback
* with the specified TAsyncSocketManager and yields.
*
* As for thrift.transport.socket, due to the limitations of std.socket,
* currently only TCP/IP sockets are supported (i.e. Unix domain sockets are
* not).
*/
class TAsyncSocket : TSocketBase, TAsyncTransport {
/**
* Constructor that takes an already created, connected (!) socket.
*
* Params:
* asyncManager = The TAsyncSocketManager to use for non-blocking I/O.
* socket = Already created, connected socket object. Will be switched to
* non-blocking mode if it isn't already.
*/
this(TAsyncSocketManager asyncManager, Socket socket) {
asyncManager_ = asyncManager;
socket.blocking = false;
super(socket);
}
/**
* Creates a new unconnected socket that will connect to the given host
* on the given port.
*
* Params:
* asyncManager = The TAsyncSocketManager to use for non-blocking I/O.
* host = Remote host.
* port = Remote port.
*/
this(TAsyncSocketManager asyncManager, string host, ushort port) {
asyncManager_ = asyncManager;
super(host, port);
}
override TAsyncManager asyncManager() @property {
return asyncManager_;
}
/**
* Asynchronously connects the socket.
*
* Completes without blocking and defers further operations on the socket
* until the connection is established. If connecting fails, this is
* currently not indicated in any way other than every call to read/write
* failing.
*/
override void open() {
if (isOpen) return;
enforce(!host_.empty, new TTransportException(
"Cannot open null host.", TTransportException.Type.NOT_OPEN));
enforce(port_ != 0, new TTransportException(
"Cannot open with null port.", TTransportException.Type.NOT_OPEN));
// Cannot use std.socket.Socket.connect here because it hides away
// EINPROGRESS/WSAWOULDBLOCK.
Address addr;
try {
// Currently, we just go with the first address returned, could be made
// more intelligent though IPv6?
addr = getAddress(host_, port_)[0];
} catch (Exception e) {
throw new TTransportException(`Unable to resolve host "` ~ host_ ~ `".`,
TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
}
socket_ = new TcpSocket(addr.addressFamily);
socket_.blocking = false;
setSocketOpts();
auto errorCode = connect(socket_.handle, addr.name(), addr.nameLen());
if (errorCode == 0) {
// If the connection could be established immediately, just return. I
// don't know if this ever happens.
return;
}
auto errno = getSocketErrno();
if (errno != CONNECT_INPROGRESS_ERRNO) {
throw new TTransportException(`Could not establish connection to "` ~
host_ ~ `": ` ~ socketErrnoString(errno),
TTransportException.Type.NOT_OPEN);
}
// This is the expected case: connect() signalled that the connection
// is being established in the background. Queue up a work item with the
// async manager which just defers any other operations on this
// TAsyncSocket instance until the socket is ready.
asyncManager_.execute(this,
{
auto fiber = Fiber.getThis();
TAsyncEventReason reason = void;
asyncManager_.addOneshotListener(socket_, TAsyncEventType.WRITE,
connectTimeout,
scopedDelegate((TAsyncEventReason r){ reason = r; fiber.call(); })
);
Fiber.yield();
if (reason == TAsyncEventReason.TIMED_OUT) {
// Close the connection, so that subsequent work items fail immediately.
closeImmediately();
return;
}
int errorCode = void;
socket_.getOption(SocketOptionLevel.SOCKET, cast(SocketOption)SO_ERROR,
errorCode);
if (errorCode) {
logInfo("Could not connect TAsyncSocket: %s",
socketErrnoString(errorCode));
// Close the connection, so that subsequent work items fail immediately.
closeImmediately();
return;
}
}
);
}
/**
* Closes the socket.
*
* Will block until all currently active operations are finished before the
* socket is closed.
*/
override void close() {
if (!isOpen) return;
import core.sync.condition;
import core.sync.mutex;
auto doneMutex = new Mutex;
auto doneCond = new Condition(doneMutex);
synchronized (doneMutex) {
asyncManager_.execute(this,
scopedDelegate(
{
closeImmediately();
synchronized (doneMutex) doneCond.notifyAll();
}
)
);
doneCond.wait();
}
}
override bool peek() {
if (!isOpen) return false;
ubyte buf;
auto r = socket_.receive((&buf)[0..1], SocketFlags.PEEK);
if (r == Socket.ERROR) {
auto lastErrno = getSocketErrno();
static if (connresetOnPeerShutdown) {
if (lastErrno == ECONNRESET) {
closeImmediately();
return false;
}
}
throw new TTransportException("Peeking into socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
return (r > 0);
}
override size_t read(ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));
typeof(getSocketErrno()) lastErrno;
auto r = yieldOnBlock(socket_.receive(cast(void[])buf),
TAsyncEventType.READ);
// If recv went fine, immediately return.
if (r >= 0) return r;
// Something went wrong, find out how to handle it.
lastErrno = getSocketErrno();
static if (connresetOnPeerShutdown) {
// See top comment.
if (lastErrno == ECONNRESET) {
return 0;
}
}
throw new TTransportException("Receiving from socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
override void write(in ubyte[] buf) {
size_t sent;
while (sent < buf.length) {
sent += writeSome(buf[sent .. $]);
}
assert(sent == buf.length);
}
override size_t writeSome(in ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot write if socket is not open.", TTransportException.Type.NOT_OPEN));
auto r = yieldOnBlock(socket_.send(buf), TAsyncEventType.WRITE);
// Everything went well, just return the number of bytes written.
if (r > 0) return r;
// Handle error conditions.
if (r < 0) {
auto lastErrno = getSocketErrno();
auto type = TTransportException.Type.UNKNOWN;
if (isSocketCloseErrno(lastErrno)) {
type = TTransportException.Type.NOT_OPEN;
closeImmediately();
}
throw new TTransportException("Sending to socket failed: " ~
socketErrnoString(lastErrno), type);
}
// send() should never return 0.
throw new TTransportException("Sending to socket failed (0 bytes written).",
TTransportException.Type.UNKNOWN);
}
/// The amount of time in which a conncetion must be established before the
/// open() call times out.
Duration connectTimeout = dur!"seconds"(5);
private:
void closeImmediately() {
socket_.close();
socket_ = null;
}
T yieldOnBlock(T)(lazy T call, TAsyncEventType eventType) {
while (true) {
auto result = call();
if (result != Socket.ERROR || getSocketErrno() != WOULD_BLOCK_ERRNO) return result;
// We got an EAGAIN result, register a callback to return here once some
// event happens and yield.
Duration timeout = void;
final switch (eventType) {
case TAsyncEventType.READ:
timeout = recvTimeout_;
break;
case TAsyncEventType.WRITE:
timeout = sendTimeout_;
break;
}
auto fiber = Fiber.getThis();
assert(fiber, "Current fiber null not running in TAsyncManager?");
TAsyncEventReason eventReason = void;
asyncManager_.addOneshotListener(socket_, eventType, timeout,
scopedDelegate((TAsyncEventReason reason) {
eventReason = reason;
fiber.call();
})
);
// Yields execution back to the async manager, will return back here once
// the above listener is called.
Fiber.yield();
if (eventReason == TAsyncEventReason.TIMED_OUT) {
// If we are cancelling the request due to a timed out operation, the
// connection is in an undefined state, because the server could decide
// to send the requested data later, or we could have already been half-
// way into writing a request. Thus, we close the connection to make any
// possibly queued up work items fail immediately. Besides, the server
// is not very likely to immediately recover after a socket-level
// timeout has expired anyway.
closeImmediately();
throw new TTransportException("Timed out while waiting for socket " ~
"to get ready to " ~ to!string(eventType) ~ ".",
TTransportException.Type.TIMED_OUT);
}
}
}
/// The TAsyncSocketManager to use for non-blocking I/O.
TAsyncSocketManager asyncManager_;
}
private {
// std.socket doesn't include SO_ERROR for reasons unknown.
version (linux) {
enum SO_ERROR = 4;
} else version (OSX) {
enum SO_ERROR = 0x1007;
} else version (FreeBSD) {
enum SO_ERROR = 0x1007;
} else version (Win32) {
import std.c.windows.winsock : SO_ERROR;
} else static assert(false, "Don't know SO_ERROR on this platform.");
// This hack forces a delegate literal to be scoped, even if it is passed to
// a function accepting normal delegates as well. DMD likes to allocate the
// context on the heap anyway, but it seems to work for LDC.
import std.traits : isDelegate;
auto scopedDelegate(D)(scope D d) if (isDelegate!D) {
return d;
}
}

View file

@ -0,0 +1,292 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.async.ssl;
import core.thread : Fiber;
import core.time : Duration;
import std.array : empty;
import std.conv : to;
import std.exception : enforce;
import std.socket;
import deimos.openssl.err;
import deimos.openssl.ssl;
import thrift.base;
import thrift.async.base;
import thrift.async.socket;
import thrift.internal.ssl;
import thrift.internal.ssl_bio;
import thrift.transport.base;
import thrift.transport.ssl;
/**
* Provides SSL/TLS encryption for async sockets.
*
* This implementation should be considered experimental, as it context-switches
* between fibers from within OpenSSL calls, and the safety of this has not yet
* been verified.
*
* For obvious reasons (the SSL connection is stateful), more than one instance
* should never be used on a given socket at the same time.
*/
// Note: This could easily be extended to other transports in the future as well.
// There are only two parts of the implementation which don't work with a generic
// TTransport: 1) the certificate verification, for which peer name/address are
// needed from the socket, and 2) the connection shutdown, where the associated
// async manager is needed because close() is not usually called from within a
// work item.
final class TAsyncSSLSocket : TBaseTransport {
/**
* Constructor.
*
* Params:
* context = The SSL socket context to use. A reference to it is stored so
* that it does not get cleaned up while the socket is used.
* transport = The underlying async network transport to use for
* communication.
*/
this(TAsyncSocket underlyingSocket, TSSLContext context) {
socket_ = underlyingSocket;
context_ = context;
serverSide_ = context.serverSide;
accessManager_ = context.accessManager;
}
override bool isOpen() @property {
if (ssl_ is null || !socket_.isOpen) return false;
auto shutdown = SSL_get_shutdown(ssl_);
bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN) != 0;
bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN) != 0;
return !(shutdownReceived && shutdownSent);
}
override bool peek() {
if (!isOpen) return false;
checkHandshake();
byte bt = void;
auto rc = SSL_peek(ssl_, &bt, bt.sizeof);
sslEnforce(rc >= 0, "SSL_peek");
if (rc == 0) {
ERR_clear_error();
}
return (rc > 0);
}
override void open() {
enforce(!serverSide_, "Cannot open a server-side SSL socket.");
if (isOpen) return;
if (ssl_) {
// If the underlying socket was automatically closed because of an error
// (i.e. close() was called from inside a socket method), we can land
// here with the SSL object still allocated; delete it here.
cleanupSSL();
}
socket_.open();
}
override void close() {
if (!isOpen) return;
if (ssl_ !is null) {
// SSL needs to send/receive data over the socket as part of the shutdown
// protocol, so we must execute the calls in the context of the associated
// async manager. On the other hand, TTransport clients expect the socket
// to be closed when close() returns, so we have to block until the
// shutdown work item has been executed.
import core.sync.condition;
import core.sync.mutex;
int rc = void;
auto doneMutex = new Mutex;
auto doneCond = new Condition(doneMutex);
synchronized (doneMutex) {
socket_.asyncManager.execute(socket_, {
rc = SSL_shutdown(ssl_);
if (rc == 0) {
rc = SSL_shutdown(ssl_);
}
synchronized (doneMutex) doneCond.notifyAll();
});
doneCond.wait();
}
if (rc < 0) {
// Do not throw an exception here as leaving the transport "open" will
// probably produce only more errors, and the chance we can do
// something about the error e.g. by retrying is very low.
logError("Error while shutting down SSL: %s", getSSLException());
}
cleanupSSL();
}
socket_.close();
}
override size_t read(ubyte[] buf) {
checkHandshake();
auto rc = SSL_read(ssl_, buf.ptr, cast(int)buf.length);
sslEnforce(rc >= 0, "SSL_read");
return rc;
}
override void write(in ubyte[] buf) {
checkHandshake();
// Loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
size_t written = 0;
while (written < buf.length) {
auto bytes = SSL_write(ssl_, buf.ptr + written,
cast(int)(buf.length - written));
sslEnforce(bytes > 0, "SSL_write");
written += bytes;
}
}
override void flush() {
checkHandshake();
auto bio = SSL_get_wbio(ssl_);
enforce(bio !is null, new TSSLException("SSL_get_wbio returned null"));
auto rc = BIO_flush(bio);
sslEnforce(rc == 1, "BIO_flush");
}
/**
* Whether to use client or server side SSL handshake protocol.
*/
bool serverSide() @property const {
return serverSide_;
}
/// Ditto
void serverSide(bool value) @property {
serverSide_ = value;
}
/**
* The access manager to use.
*/
void accessManager(TAccessManager value) @property {
accessManager_ = value;
}
private:
/**
* If the condition is false, cleans up the SSL connection and throws the
* exception for the last SSL error.
*/
void sslEnforce(bool condition, string location) {
if (!condition) {
// We need to fetch the error first, as the error stack will be cleaned
// when shutting down SSL.
auto e = getSSLException(location);
cleanupSSL();
throw e;
}
}
/**
* Frees the SSL connection object and clears the SSL error state.
*/
void cleanupSSL() {
SSL_free(ssl_);
ssl_ = null;
ERR_remove_state(0);
}
/**
* Makes sure the SSL connection is up and running, and initializes it if not.
*/
void checkHandshake() {
enforce(socket_.isOpen, new TTransportException(
TTransportException.Type.NOT_OPEN));
if (ssl_ !is null) return;
ssl_ = context_.createSSL();
auto bio = createTTransportBIO(socket_, false);
SSL_set_bio(ssl_, bio, bio);
int rc = void;
if (serverSide_) {
rc = SSL_accept(ssl_);
} else {
rc = SSL_connect(ssl_);
}
enforce(rc > 0, getSSLException());
auto addr = socket_.getPeerAddress();
authorize(ssl_, accessManager_, addr,
(serverSide_ ? addr.toHostNameString() : socket_.host));
}
TAsyncSocket socket_;
bool serverSide_;
SSL* ssl_;
TSSLContext context_;
TAccessManager accessManager_;
}
/**
* Wraps passed TAsyncSocket instances into TAsyncSSLSockets.
*
* Typically used with TAsyncClient. As an unfortunate consequence of the
* async client design, the passed transports cannot be statically verified to
* be of type TAsyncSocket. Instead, the type is verified at runtime if a
* transport of an unexpected type is passed to getTransport(), it fails,
* throwing a TTransportException.
*
* Example:
* ---
* auto context = nwe TSSLContext();
* ... // Configure SSL context.
* auto factory = new TAsyncSSLSocketFactory(context);
*
* auto socket = new TAsyncSocket(someAsyncManager, host, port);
* socket.open();
*
* auto client = new TAsyncClient!Service(transport, factory,
* new TBinaryProtocolFactory!());
* ---
*/
class TAsyncSSLSocketFactory : TTransportFactory {
///
this(TSSLContext context) {
context_ = context;
}
override TAsyncSSLSocket getTransport(TTransport transport) {
auto socket = cast(TAsyncSocket)transport;
enforce(socket, new TTransportException(
"TAsyncSSLSocketFactory requires a TAsyncSocket to work on, not a " ~
to!string(typeid(transport)) ~ ".",
TTransportException.Type.INTERNAL_ERROR
));
return new TAsyncSSLSocket(socket, context_);
}
private:
TSSLContext context_;
}

View file

@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.base;
/**
* Common base class for all Thrift exceptions.
*/
class TException : Exception {
///
this(string msg = "", string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
}
}
/**
* An operation failed because one or more sub-tasks failed.
*/
class TCompoundOperationException : TException {
///
this(string msg, Exception[] exceptions, string file = __FILE__,
size_t line = __LINE__, Throwable next = null)
{
super(msg, file, line, next);
this.exceptions = exceptions;
}
/// The exceptions thrown by the children of the operation. If applicable,
/// the list is ordered in the same way the exceptions occurred.
Exception[] exceptions;
}
/// The Thrift version string, used for informative purposes.
// Note: This is currently hardcoded, but will likely be filled in by the build
// system in future versions.
enum VERSION = "0.10.0";
/**
* Functions used for logging inside Thrift.
*
* By default, the formatted messages are written to stdout/stderr, but this
* behavior can be overwritten by providing custom g_{Info, Error}LogSink
* handlers.
*
* Examples:
* ---
* logInfo("An informative message.");
* logError("Some error occurred: %s", e);
* ---
*/
alias logFormatted!g_infoLogSink logInfo;
alias logFormatted!g_errorLogSink logError; /// Ditto
/**
* Error and info log message sinks.
*
* These delegates are called with the log message passed as const(char)[]
* argument, and can be overwritten to hook the Thrift libraries up with a
* custom logging system. By default, they forward all output to stdout/stderr.
*/
__gshared void delegate(const(char)[]) g_infoLogSink;
__gshared void delegate(const(char)[]) g_errorLogSink; /// Ditto
shared static this() {
import std.stdio;
g_infoLogSink = (const(char)[] text) {
stdout.writeln(text);
};
g_errorLogSink = (const(char)[] text) {
stderr.writeln(text);
};
}
// This should be private, if it could still be used through the aliases then.
template logFormatted(alias target) {
void logFormatted(string file = __FILE__, int line = __LINE__,
T...)(string fmt, T args) if (
__traits(compiles, { target(""); })
) {
import std.format, std.stdio;
if (target !is null) {
scope(exit) g_formatBuffer.clear();
// Phobos @@BUG@@: If the empty string put() is removed, Appender.data
// stays empty.
g_formatBuffer.put("");
formattedWrite(g_formatBuffer, "%s:%s: ", file, line);
static if (T.length == 0) {
g_formatBuffer.put(fmt);
} else {
formattedWrite(g_formatBuffer, fmt, args);
}
target(g_formatBuffer.data);
}
}
}
private {
// Use a global, but thread-local buffer for constructing log messages.
import std.array : Appender;
Appender!(char[]) g_formatBuffer;
}

View file

@ -0,0 +1,255 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.async_client;
import std.conv : text, to;
import std.traits : ParameterStorageClass, ParameterStorageClassTuple,
ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.async.base;
import thrift.codegen.base;
import thrift.codegen.client;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.util.cancellation;
import thrift.util.future;
/**
* Asynchronous Thrift service client which returns the results as TFutures an
* uses a TAsyncManager to perform the actual work.
*
* TAsyncClientBase serves as a supertype for all TAsyncClients for the same
* service, which might be instantiated with different concrete protocol types
* (there is no covariance for template type parameters), and extends
* TFutureInterface!Interface. If Interface is derived from another service
* BaseInterface, it also extends TAsyncClientBase!BaseInterface.
*
* TAsyncClient implements TAsyncClientBase and offers two constructors with
* the following signatures:
* ---
* this(TAsyncTransport trans, TTransportFactory tf, TProtocolFactory pf);
* this(TAsyncTransport trans, TTransportFactory itf, TTransportFactory otf,
* TProtocolFactory ipf, TProtocolFactory opf);
* ---
*
* Again, if Interface represents a derived Thrift service,
* TAsyncClient!Interface is also derived from TAsyncClient!BaseInterface.
*
* TAsyncClient can exclusively be used with TAsyncTransports, as it needs to
* access the associated TAsyncManager. To set up any wrapper transports
* (e.g. buffered, framed) on top of it and to instanciate the protocols to use,
* TTransportFactory and TProtocolFactory instances are passed to the
* constructors the three argument constructor is a shortcut if the same
* transport and protocol are to be used for both input and output, which is
* the most common case.
*
* If the same transport factory is passed for both input and output transports,
* only a single wrapper transport will be created and used for both directions.
* This allows easy implementation of protocols like SSL.
*
* Just as TClient does, TAsyncClient also takes two optional template
* arguments which can be used for specifying the actual TProtocol
* implementation used for optimization purposes, as virtual calls can
* completely be eliminated then. If the actual types of the protocols
* instantiated by the factories used does not match the ones statically
* specified in the template parameters, a TException is thrown during
* construction.
*
* Example:
* ---
* // A simple Thrift service.
* interface Foo { int foo(); }
*
* // Create a TAsyncSocketManager thrift.async.libevent is used for this
* // example.
* auto manager = new TLibeventAsyncManager;
*
* // Set up an async transport to use.
* auto socket = new TAsyncSocket(manager, host, port);
*
* // Create a client instance.
* auto client = new TAsyncClient!Foo(
* socket,
* new TBufferedTransportFactory, // Wrap the socket in a TBufferedTransport.
* new TBinaryProtocolFactory!() // Use the Binary protocol.
* );
*
* // Call foo and use the returned future.
* auto result = client.foo();
* pragma(msg, typeof(result)); // TFuture!int
* int resultValue = result.waitGet(); // Waits until the result is available.
* ---
*/
interface TAsyncClientBase(Interface) if (isBaseService!Interface) :
TFutureInterface!Interface
{
/**
* The underlying TAsyncTransport used by this client instance.
*/
TAsyncTransport transport() @property;
}
/// Ditto
interface TAsyncClientBase(Interface) if (isDerivedService!Interface) :
TAsyncClientBase!(BaseService!Interface), TFutureInterface!Interface
{}
/// Ditto
template TAsyncClient(Interface, InputProtocol = TProtocol, OutputProtocol = void) if (
isService!Interface && isTProtocol!InputProtocol &&
(isTProtocol!OutputProtocol || is(OutputProtocol == void))
) {
mixin({
static if (isDerivedService!Interface) {
string code = "class TAsyncClient : " ~
"TAsyncClient!(BaseService!Interface, InputProtocol, OutputProtocol), " ~
"TAsyncClientBase!Interface {\n";
code ~= q{
this(TAsyncTransport trans, TTransportFactory tf, TProtocolFactory pf) {
this(trans, tf, tf, pf, pf);
}
this(TAsyncTransport trans, TTransportFactory itf,
TTransportFactory otf, TProtocolFactory ipf, TProtocolFactory opf
) {
super(trans, itf, otf, ipf, opf);
client_ = new typeof(client_)(iprot_, oprot_);
}
private TClient!(Interface, IProt, OProt) client_;
};
} else {
string code = "class TAsyncClient : TAsyncClientBase!Interface {";
code ~= q{
alias InputProtocol IProt;
static if (isTProtocol!OutputProtocol) {
alias OutputProtocol OProt;
} else {
static assert(is(OutputProtocol == void));
alias InputProtocol OProt;
}
this(TAsyncTransport trans, TTransportFactory tf, TProtocolFactory pf) {
this(trans, tf, tf, pf, pf);
}
this(TAsyncTransport trans, TTransportFactory itf,
TTransportFactory otf, TProtocolFactory ipf, TProtocolFactory opf
) {
import std.exception;
transport_ = trans;
auto ip = itf.getTransport(trans);
TTransport op = void;
if (itf == otf) {
op = ip;
} else {
op = otf.getTransport(trans);
}
auto iprot = ipf.getProtocol(ip);
iprot_ = cast(IProt)iprot;
enforce(iprot_, new TException(text("Input protocol not of the " ~
"specified concrete type (", IProt.stringof, ").")));
auto oprot = opf.getProtocol(op);
oprot_ = cast(OProt)oprot;
enforce(oprot_, new TException(text("Output protocol not of the " ~
"specified concrete type (", OProt.stringof, ").")));
client_ = new typeof(client_)(iprot_, oprot_);
}
override TAsyncTransport transport() @property {
return transport_;
}
protected TAsyncTransport transport_;
protected IProt iprot_;
protected OProt oprot_;
private TClient!(Interface, IProt, OProt) client_;
};
}
foreach (methodName;
FilterMethodNames!(Interface, __traits(derivedMembers, Interface))
) {
string[] paramList;
string[] paramNames;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
immutable paramName = "param" ~ to!string(i + 1);
immutable storage = ParameterStorageClassTuple!(
mixin("Interface." ~ methodName))[i];
paramList ~= ((storage & ParameterStorageClass.ref_) ? "ref " : "") ~
"ParameterTypeTuple!(Interface." ~ methodName ~ ")[" ~
to!string(i) ~ "] " ~ paramName;
paramNames ~= paramName;
}
paramList ~= "TCancellation cancellation = null";
immutable returnTypeCode = "ReturnType!(Interface." ~ methodName ~ ")";
code ~= "TFuture!(" ~ returnTypeCode ~ ") " ~ methodName ~ "(" ~
ctfeJoin(paramList) ~ ") {\n";
// Create the future instance that will repesent the result.
code ~= "auto promise = new TPromise!(" ~ returnTypeCode ~ ");\n";
// Prepare delegate which executes the TClient method call.
code ~= "auto work = {\n";
code ~= "try {\n";
code ~= "static if (is(ReturnType!(Interface." ~ methodName ~
") == void)) {\n";
code ~= "client_." ~ methodName ~ "(" ~ ctfeJoin(paramNames) ~ ");\n";
code ~= "promise.succeed();\n";
code ~= "} else {\n";
code ~= "auto result = client_." ~ methodName ~ "(" ~
ctfeJoin(paramNames) ~ ");\n";
code ~= "promise.succeed(result);\n";
code ~= "}\n";
code ~= "} catch (Exception e) {\n";
code ~= "promise.fail(e);\n";
code ~= "}\n";
code ~= "};\n";
// If the request is cancelled, set the result promise to cancelled
// as well. This could be moved into an additional TAsyncWorkItem
// delegate parameter.
code ~= q{
if (cancellation) {
cancellation.triggering.addCallback({
promise.cancel();
});
}
};
// Enqueue the work item and immediately return the promise (resp. its
// future interface).
code ~= "transport_.asyncManager.execute(transport_, work, cancellation);\n";
code ~= "return promise;\n";
code ~= "}\n";
}
code ~= "}\n";
return code;
}());
}

View file

@ -0,0 +1,906 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Utilities for asynchronously querying multiple servers, building on
* TAsyncClient.
*
* Terminology note: The names of the artifacts defined in this module are
* derived from »client pool«, because they operate on a pool of
* TAsyncClients. However, from a architectural point of view, they often
* represent a pool of hosts a Thrift client application communicates with
* using RPC calls.
*/
module thrift.codegen.async_client_pool;
import core.sync.mutex;
import core.time : Duration, dur;
import std.algorithm : map;
import std.array : array, empty;
import std.exception : enforce;
import std.traits : ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.codegen.base;
import thrift.codegen.async_client;
import thrift.internal.algorithm;
import thrift.internal.codegen;
import thrift.util.awaitable;
import thrift.util.cancellation;
import thrift.util.future;
import thrift.internal.resource_pool;
/**
* Represents a generic client pool which implements TFutureInterface!Interface
* using multiple TAsyncClients.
*/
interface TAsyncClientPoolBase(Interface) if (isService!Interface) :
TFutureInterface!Interface
{
/// Shorthand for the client type this pool operates on.
alias TAsyncClientBase!Interface Client;
/**
* Adds a client to the pool.
*/
void addClient(Client client);
/**
* Removes a client from the pool.
*
* Returns: Whether the client was found in the pool.
*/
bool removeClient(Client client);
/**
* Called to determine whether an exception comes from a client from the
* pool not working properly, or if it an exception thrown at the
* application level.
*
* If the delegate returns true, the server/connection is considered to be
* at fault, if it returns false, the exception is just passed on to the
* caller.
*
* By default, returns true for instances of TTransportException and
* TApplicationException, false otherwise.
*/
bool delegate(Exception) rpcFaultFilter() const @property;
void rpcFaultFilter(bool delegate(Exception)) @property; /// Ditto
/**
* Whether to open the underlying transports of a client before trying to
* execute a method if they are not open. This is usually desirable
* because it allows e.g. to automatically reconnect to a remote server
* if the network connection is dropped.
*
* Defaults to true.
*/
bool reopenTransports() const @property;
void reopenTransports(bool) @property; /// Ditto
}
immutable bool delegate(Exception) defaultRpcFaultFilter;
static this() {
defaultRpcFaultFilter = (Exception e) {
import thrift.protocol.base;
import thrift.transport.base;
return (
(cast(TTransportException)e !is null) ||
(cast(TApplicationException)e !is null)
);
};
}
/**
* A TAsyncClientPoolBase implementation which queries multiple servers in a
* row until a request succeeds, the result of which is then returned.
*
* The definition of »success« can be customized using the rpcFaultFilter()
* delegate property. If it is non-null and calling it for an exception set by
* a failed method invocation returns true, the error is considered to be
* caused by the RPC layer rather than the application layer, and the next
* server in the pool is tried. If there are no more clients to try, the
* operation is marked as failed with a TCompoundOperationException.
*
* If a TAsyncClient in the pool fails with an RPC exception for a number of
* consecutive tries, it is temporarily disabled (not tried any longer) for
* a certain duration. Both the limit and the timeout can be configured. If all
* clients fail (and keepTrying is false), the operation fails with a
* TCompoundOperationException which contains the collected RPC exceptions.
*/
final class TAsyncClientPool(Interface) if (isService!Interface) :
TAsyncClientPoolBase!Interface
{
///
this(Client[] clients) {
pool_ = new TResourcePool!Client(clients);
rpcFaultFilter_ = defaultRpcFaultFilter;
reopenTransports_ = true;
}
/+override+/ void addClient(Client client) {
pool_.add(client);
}
/+override+/ bool removeClient(Client client) {
return pool_.remove(client);
}
/**
* Whether to keep trying to find a working client if all have failed in a
* row.
*
* Defaults to false.
*/
bool keepTrying() const @property {
return pool_.cycle;
}
/// Ditto
void keepTrying(bool value) @property {
pool_.cycle = value;
}
/**
* Whether to use a random permutation of the client pool on every call to
* execute(). This can be used e.g. as a simple form of load balancing.
*
* Defaults to true.
*/
bool permuteClients() const @property {
return pool_.permute;
}
/// Ditto
void permuteClients(bool value) @property {
pool_.permute = value;
}
/**
* The number of consecutive faults after which a client is disabled until
* faultDisableDuration has passed. 0 to never disable clients.
*
* Defaults to 0.
*/
ushort faultDisableCount() const @property {
return pool_.faultDisableCount;
}
/// Ditto
void faultDisableCount(ushort value) @property {
pool_.faultDisableCount = value;
}
/**
* The duration for which a client is no longer considered after it has
* failed too often.
*
* Defaults to one second.
*/
Duration faultDisableDuration() const @property {
return pool_.faultDisableDuration;
}
/// Ditto
void faultDisableDuration(Duration value) @property {
pool_.faultDisableDuration = value;
}
/+override+/ bool delegate(Exception) rpcFaultFilter() const @property {
return rpcFaultFilter_;
}
/+override+/ void rpcFaultFilter(bool delegate(Exception) value) @property {
rpcFaultFilter_ = value;
}
/+override+/ bool reopenTransports() const @property {
return reopenTransports_;
}
/+override+/ void reopenTransports(bool value) @property {
reopenTransports_ = value;
}
mixin(fallbackPoolForwardCode!Interface());
protected:
// The actual worker implementation to which RPC method calls are forwarded.
auto executeOnPool(string method, Args...)(Args args,
TCancellation cancellation
) {
auto clients = pool_[];
if (clients.empty) {
throw new TException("No clients available to try.");
}
auto promise = new TPromise!(ReturnType!(MemberType!(Interface, method)));
Exception[] rpcExceptions;
void tryNext() {
while (clients.empty) {
Client next;
Duration waitTime;
if (clients.willBecomeNonempty(next, waitTime)) {
if (waitTime > dur!"hnsecs"(0)) {
if (waitTime < dur!"usecs"(10)) {
import core.thread;
Thread.sleep(waitTime);
} else {
next.transport.asyncManager.delay(waitTime, { tryNext(); });
return;
}
}
} else {
promise.fail(new TCompoundOperationException("All clients failed.",
rpcExceptions));
return;
}
}
auto client = clients.front;
clients.popFront;
if (reopenTransports) {
if (!client.transport.isOpen) {
try {
client.transport.open();
} catch (Exception e) {
pool_.recordFault(client);
tryNext();
return;
}
}
}
auto future = mixin("client." ~ method)(args, cancellation);
future.completion.addCallback({
if (future.status == TFutureStatus.CANCELLED) {
promise.cancel();
return;
}
auto e = future.getException();
if (e) {
if (rpcFaultFilter_ && rpcFaultFilter_(e)) {
pool_.recordFault(client);
rpcExceptions ~= e;
tryNext();
return;
}
}
pool_.recordSuccess(client);
promise.complete(future);
});
}
tryNext();
return promise;
}
private:
TResourcePool!Client pool_;
bool delegate(Exception) rpcFaultFilter_;
bool reopenTransports_;
}
/**
* TAsyncClientPool construction helper to avoid having to explicitly
* specify the interface type, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TAsyncClientPool!Interface tAsyncClientPool(Interface)(
TAsyncClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}
private {
// Cannot use an anonymous delegate literal for this because they aren't
// allowed in class scope.
string fallbackPoolForwardCode(Interface)() {
string code = "";
foreach (methodName; AllMemberMethodNames!Interface) {
enum qn = "Interface." ~ methodName;
code ~= "TFuture!(ReturnType!(" ~ qn ~ ")) " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ ") args, TCancellation cancellation = null) {\n";
code ~= "return executeOnPool!(`" ~ methodName ~ "`)(args, cancellation);\n";
code ~= "}\n";
}
return code;
}
}
/**
* A TAsyncClientPoolBase implementation which queries multiple servers at
* the same time and returns the first success response.
*
* The definition of »success« can be customized using the rpcFaultFilter()
* delegate property. If it is non-null and calling it for an exception set by
* a failed method invocation returns true, the error is considered to be
* caused by the RPC layer rather than the application layer, and the next
* server in the pool is tried. If all clients fail, the operation is marked
* as failed with a TCompoundOperationException.
*/
final class TAsyncFastestClientPool(Interface) if (isService!Interface) :
TAsyncClientPoolBase!Interface
{
///
this(Client[] clients) {
clients_ = clients;
rpcFaultFilter_ = defaultRpcFaultFilter;
reopenTransports_ = true;
}
/+override+/ void addClient(Client client) {
clients_ ~= client;
}
/+override+/ bool removeClient(Client client) {
auto oldLength = clients_.length;
clients_ = removeEqual(clients_, client);
return clients_.length < oldLength;
}
/+override+/ bool delegate(Exception) rpcFaultFilter() const @property {
return rpcFaultFilter_;
}
/+override+/ void rpcFaultFilter(bool delegate(Exception) value) @property {
rpcFaultFilter_ = value;
}
/+override+/bool reopenTransports() const @property {
return reopenTransports_;
}
/+override+/ void reopenTransports(bool value) @property {
reopenTransports_ = value;
}
mixin(fastestPoolForwardCode!Interface());
private:
Client[] clients_;
bool delegate(Exception) rpcFaultFilter_;
bool reopenTransports_;
}
/**
* TAsyncFastestClientPool construction helper to avoid having to explicitly
* specify the interface type, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TAsyncFastestClientPool!Interface tAsyncFastestClientPool(Interface)(
TAsyncClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}
private {
// Cannot use an anonymous delegate literal for this because they aren't
// allowed in class scope.
string fastestPoolForwardCode(Interface)() {
string code = "";
foreach (methodName; AllMemberMethodNames!Interface) {
enum qn = "Interface." ~ methodName;
code ~= "TFuture!(ReturnType!(" ~ qn ~ ")) " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ ") args, " ~
"TCancellation cancellation = null) {\n";
code ~= "enum methodName = `" ~ methodName ~ "`;\n";
code ~= q{
alias ReturnType!(MemberType!(Interface, methodName)) ResultType;
auto childCancellation = new TCancellationOrigin;
TFuture!ResultType[] futures;
futures.reserve(clients_.length);
foreach (c; clients_) {
if (reopenTransports) {
if (!c.transport.isOpen) {
try {
c.transport.open();
} catch (Exception e) {
continue;
}
}
}
futures ~= mixin("c." ~ methodName)(args, childCancellation);
}
return new FastestPoolJob!(ResultType)(
futures, rpcFaultFilter, cancellation, childCancellation);
};
code ~= "}\n";
}
return code;
}
final class FastestPoolJob(Result) : TFuture!Result {
this(TFuture!Result[] poolFutures, bool delegate(Exception) rpcFaultFilter,
TCancellation cancellation, TCancellationOrigin childCancellation
) {
resultPromise_ = new TPromise!Result;
poolFutures_ = poolFutures;
rpcFaultFilter_ = rpcFaultFilter;
childCancellation_ = childCancellation;
foreach (future; poolFutures) {
future.completion.addCallback({
auto f = future;
return { completionCallback(f); };
}());
if (future.status != TFutureStatus.RUNNING) {
// If the current future is already completed, we are done, don't
// bother adding callbacks for the others (they would just return
// immediately after acquiring the lock).
return;
}
}
if (cancellation) {
cancellation.triggering.addCallback({
resultPromise_.cancel();
childCancellation.trigger();
});
}
}
TFutureStatus status() const @property {
return resultPromise_.status;
}
TAwaitable completion() @property {
return resultPromise_.completion;
}
Result get() {
return resultPromise_.get();
}
Exception getException() {
return resultPromise_.getException();
}
private:
void completionCallback(TFuture!Result future) {
synchronized {
if (future.status == TFutureStatus.CANCELLED) {
assert(resultPromise_.status != TFutureStatus.RUNNING);
return;
}
if (resultPromise_.status != TFutureStatus.RUNNING) {
// The operation has already been completed. This can happen if
// another client completed first, but this callback was already
// waiting for the lock when it called cancel().
return;
}
if (future.status == TFutureStatus.FAILED) {
auto e = future.getException();
if (rpcFaultFilter_ && rpcFaultFilter_(e)) {
rpcExceptions_ ~= e;
if (rpcExceptions_.length == poolFutures_.length) {
resultPromise_.fail(new TCompoundOperationException(
"All child operations failed, unable to retrieve a result.",
rpcExceptions_
));
}
return;
}
}
// Store the result to the target promise.
resultPromise_.complete(future);
// Cancel the other futures, we would just discard their results.
// Note: We do this after we have stored the results to our promise,
// see the assert at the top of the function.
childCancellation_.trigger();
}
}
TPromise!Result resultPromise_;
TFuture!Result[] poolFutures_;
Exception[] rpcExceptions_;
bool delegate(Exception) rpcFaultFilter_;
TCancellationOrigin childCancellation_;
}
}
/**
* Allows easily aggregating results from a number of TAsyncClients.
*
* Contrary to TAsync{Fallback, Fastest}ClientPool, this class does not
* simply implement TFutureInterface!Interface. It manages a pool of clients,
* but allows the user to specify a custom accumulator function to use or to
* iterate over the results using a TFutureAggregatorRange.
*
* For each service method, TAsyncAggregator offers a method
* accepting the same arguments, and an optional TCancellation instance, just
* like with TFutureInterface. The return type, however, is a proxy object
* that offers the following methods:
* ---
* /++
* + Returns a thrift.util.future.TFutureAggregatorRange for the results of
* + the client pool method invocations.
* +
* + The [] (slicing) operator can also be used to obtain the range.
* +
* + Params:
* + timeout = A timeout to pass to the TFutureAggregatorRange constructor,
* + defaults to zero (no timeout).
* +/
* TFutureAggregatorRange!ReturnType range(Duration timeout = dur!"hnsecs"(0));
* auto opSlice() { return range(); } /// Ditto
*
* /++
* + Returns a future that gathers the results from the clients in the pool
* + and invokes a user-supplied accumulator function on them, returning its
* + return value to the client.
* +
* + In addition to the TFuture!AccumulatedType interface (where
* + AccumulatedType is the return type of the accumulator function), the
* + returned object also offers two additional methods, finish() and
* + finishGet(): By default, the accumulator functions is called after all
* + the results from the pool clients have become available. Calling finish()
* + causes the accumulator future to stop waiting for other results and
* + immediately invoking the accumulator function on the results currently
* + available. If all results are already available, finish() is a no-op.
* + finishGet() is a convenience shortcut for combining it with
* + a call to get() immediately afterwards, like waitGet() is for wait().
* +
* + The acc alias can point to any callable accepting either an array of
* + return values or an array of return values and an array of exceptions;
* + see isAccumulator!() for details. The default accumulator concatenates
* + return values that can be concatenated with each others (e.g. arrays),
* + and simply returns an array of values otherwise, failing with a
* + TCompoundOperationException no values were returned.
* +
* + The accumulator function is not executed in any of the async manager
* + worker threads associated with the async clients, but instead it is
* + invoked when the actual result is requested for the first time after the
* + operation has been completed. This also includes checking the status
* + of the operation once it is no longer running, since the accumulator
* + has to be run to determine whether the operation succeeded or failed.
* +/
* auto accumulate(alias acc = defaultAccumulator)() if (isAccumulator!acc);
* ---
*
* Example:
* ---
* // Some Thrift service.
* interface Foo {
* int foo(string name);
* byte[] bar();
* }
*
* // Create the aggregator pool client0, client1, client2 are some
* // TAsyncClient!Foo instances, but in theory could also be other
* // TFutureInterface!Foo implementations (e.g. some async client pool).
* auto pool = new TAsyncAggregator!Foo([client0, client1, client2]);
*
* foreach (val; pool.foo("baz").range(dur!"seconds"(1))) {
* // Process all the results that are available before a second has passed,
* // in the order they arrive.
* writeln(val);
* }
*
* auto sumRoots = pool.foo("baz").accumulate!((int[] vals, Exceptions[] exs){
* if (vals.empty) {
* throw new TCompoundOperationException("All clients failed", exs);
* }
*
* // Just to illustrate that the type of the values can change, convert the
* // numbers to double and sum up their roots.
* double result = 0;
* foreach (v; vals) result += sqrt(cast(double)v);
* return result;
* })();
*
* // Wait up to three seconds for the result, and then accumulate what has
* // arrived so far.
* sumRoots.completion.wait(dur!"seconds"(3));
* writeln(sumRoots.finishGet());
*
* // For scalars, the default accumulator returns an array of the values.
* pragma(msg, typeof(pool.foo("").accumulate().get()); // int[].
*
* // For lists, etc., it concatenates the results together.
* pragma(msg, typeof(pool.bar().accumulate().get())); // byte[].
* ---
*
* Note: For the accumulate!() interface, you might currently hit a »cannot use
* local '…' as parameter to non-global template accumulate«-error, see
* $(DMDBUG 5710, DMD issue 5710). If your accumulator function does not need
* to access the surrounding scope, you might want to use a function literal
* instead of a delegate to avoid the issue.
*/
class TAsyncAggregator(Interface) if (isBaseService!Interface) {
/// Shorthand for the client type this instance operates on.
alias TAsyncClientBase!Interface Client;
///
this(Client[] clients) {
clients_ = clients;
}
/// Whether to open the underlying transports of a client before trying to
/// execute a method if they are not open. This is usually desirable
/// because it allows e.g. to automatically reconnect to a remote server
/// if the network connection is dropped.
///
/// Defaults to true.
bool reopenTransports = true;
mixin AggregatorOpDispatch!();
private:
Client[] clients_;
}
/// Ditto
class TAsyncAggregator(Interface) if (isDerivedService!Interface) :
TAsyncAggregator!(BaseService!Interface)
{
/// Shorthand for the client type this instance operates on.
alias TAsyncClientBase!Interface Client;
///
this(Client[] clients) {
super(cast(TAsyncClientBase!(BaseService!Interface)[])clients);
}
mixin AggregatorOpDispatch!();
}
/**
* Whether fun is a valid accumulator function for values of type ValueType.
*
* For this to be true, fun must be a callable matching one of the following
* argument lists:
* ---
* fun(ValueType[] values);
* fun(ValueType[] values, Exception[] exceptions);
* ---
*
* The second version is passed the collected array exceptions from all the
* clients in the pool.
*
* The return value of the accumulator function is passed to the client (via
* the result future). If it throws an exception, the operation is marked as
* failed with the given exception instead.
*/
template isAccumulator(ValueType, alias fun) {
enum isAccumulator = is(typeof(fun(cast(ValueType[])[]))) ||
is(typeof(fun(cast(ValueType[])[], cast(Exception[])[])));
}
/**
* TAsyncAggregator construction helper to avoid having to explicitly
* specify the interface type, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TAsyncAggregator!Interface tAsyncAggregator(Interface)(
TAsyncClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}
private {
mixin template AggregatorOpDispatch() {
auto opDispatch(string name, Args...)(Args args) if (
is(typeof(mixin("Interface.init." ~ name)(args)))
) {
alias ReturnType!(MemberType!(Interface, name)) ResultType;
auto childCancellation = new TCancellationOrigin;
TFuture!ResultType[] futures;
futures.reserve(clients_.length);
foreach (c; cast(Client[])clients_) {
if (reopenTransports) {
if (!c.transport.isOpen) {
try {
c.transport.open();
} catch (Exception e) {
continue;
}
}
}
futures ~= mixin("c." ~ name)(args, childCancellation);
}
return AggregationResult!ResultType(futures, childCancellation);
}
}
struct AggregationResult(T) {
auto opSlice() {
return range();
}
auto range(Duration timeout = dur!"hnsecs"(0)) {
return tFutureAggregatorRange(futures_, childCancellation_, timeout);
}
auto accumulate(alias acc = defaultAccumulator)() if (isAccumulator!(T, acc)) {
return new AccumulatorJob!(T, acc)(futures_, childCancellation_);
}
private:
TFuture!T[] futures_;
TCancellationOrigin childCancellation_;
}
auto defaultAccumulator(T)(T[] values, Exception[] exceptions) {
if (values.empty) {
throw new TCompoundOperationException("All clients failed",
exceptions);
}
static if (is(typeof(T.init ~ T.init))) {
import std.algorithm;
return reduce!"a ~ b"(values);
} else {
return values;
}
}
final class AccumulatorJob(T, alias accumulator) if (
isAccumulator!(T, accumulator)
) : TFuture!(AccumulatorResult!(T, accumulator)) {
this(TFuture!T[] futures, TCancellationOrigin childCancellation) {
futures_ = futures;
childCancellation_ = childCancellation;
resultMutex_ = new Mutex;
completionEvent_ = new TOneshotEvent;
foreach (future; futures) {
future.completion.addCallback({
auto f = future;
return {
synchronized (resultMutex_) {
if (f.status == TFutureStatus.CANCELLED) {
if (!finished_) {
status_ = TFutureStatus.CANCELLED;
finished_ = true;
}
return;
}
if (f.status == TFutureStatus.FAILED) {
exceptions_ ~= f.getException();
} else {
results_ ~= f.get();
}
if (results_.length + exceptions_.length == futures_.length) {
finished_ = true;
completionEvent_.trigger();
}
}
};
}());
}
}
TFutureStatus status() @property {
synchronized (resultMutex_) {
if (!finished_) return TFutureStatus.RUNNING;
if (status_ != TFutureStatus.RUNNING) return status_;
try {
result_ = invokeAccumulator!accumulator(results_, exceptions_);
status_ = TFutureStatus.SUCCEEDED;
} catch (Exception e) {
exception_ = e;
status_ = TFutureStatus.FAILED;
}
return status_;
}
}
TAwaitable completion() @property {
return completionEvent_;
}
AccumulatorResult!(T, accumulator) get() {
auto s = status;
enforce(s != TFutureStatus.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == TFutureStatus.CANCELLED) throw new TCancelledException;
if (s == TFutureStatus.FAILED) throw exception_;
return result_;
}
Exception getException() {
auto s = status;
enforce(s != TFutureStatus.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == TFutureStatus.CANCELLED) throw new TCancelledException;
if (s == TFutureStatus.SUCCEEDED) {
return null;
}
return exception_;
}
void finish() {
synchronized (resultMutex_) {
if (!finished_) {
finished_ = true;
childCancellation_.trigger();
completionEvent_.trigger();
}
}
}
auto finishGet() {
finish();
return get();
}
private:
TFuture!T[] futures_;
TCancellationOrigin childCancellation_;
bool finished_;
T[] results_;
Exception[] exceptions_;
TFutureStatus status_;
Mutex resultMutex_;
union {
AccumulatorResult!(T, accumulator) result_;
Exception exception_;
}
TOneshotEvent completionEvent_;
}
auto invokeAccumulator(alias accumulator, T)(
T[] values, Exception[] exceptions
) if (
isAccumulator!(T, accumulator)
) {
static if (is(typeof(accumulator(values, exceptions)))) {
return accumulator(values, exceptions);
} else {
return accumulator(values);
}
}
template AccumulatorResult(T, alias acc) {
alias typeof(invokeAccumulator!acc(cast(T[])[], cast(Exception[])[]))
AccumulatorResult;
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,486 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.client;
import std.algorithm : find;
import std.array : empty, front;
import std.conv : to;
import std.traits : isSomeFunction, ParameterStorageClass,
ParameterStorageClassTuple, ParameterTypeTuple, ReturnType;
import thrift.codegen.base;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.protocol.base;
/**
* Thrift service client, which implements an interface by synchronously
* calling a server over a TProtocol.
*
* TClientBase simply extends Interface with generic input/output protocol
* properties to serve as a supertype for all TClients for the same service,
* which might be instantiated with different concrete protocol types (there
* is no covariance for template type parameters). If Interface is derived
* from another interface BaseInterface, it also extends
* TClientBase!BaseInterface.
*
* TClient is the class that actually implements TClientBase. Just as
* TClientBase, it is also derived from TClient!BaseInterface for inheriting
* services.
*
* TClient takes two optional template arguments which can be used for
* specifying the actual TProtocol implementation used for optimization
* purposes, as virtual calls can completely be eliminated then. If
* OutputProtocol is not specified, it is assumed to be the same as
* InputProtocol. The protocol properties defined by TClientBase are exposed
* with their concrete type (return type covariance).
*
* In addition to implementing TClientBase!Interface, TClient offers the
* following constructors:
* ---
* this(InputProtocol iprot, OutputProtocol oprot);
* // Only if is(InputProtocol == OutputProtocol), to use the same protocol
* // for both input and output:
* this(InputProtocol prot);
* ---
*
* The sequence id of the method calls starts at zero and is automatically
* incremented.
*/
interface TClientBase(Interface) if (isBaseService!Interface) : Interface {
/**
* The input protocol used by the client.
*/
TProtocol inputProtocol() @property;
/**
* The output protocol used by the client.
*/
TProtocol outputProtocol() @property;
}
/// Ditto
interface TClientBase(Interface) if (isDerivedService!Interface) :
TClientBase!(BaseService!Interface), Interface {}
/// Ditto
template TClient(Interface, InputProtocol = TProtocol, OutputProtocol = void) if (
isService!Interface && isTProtocol!InputProtocol &&
(isTProtocol!OutputProtocol || is(OutputProtocol == void))
) {
mixin({
static if (isDerivedService!Interface) {
string code = "class TClient : TClient!(BaseService!Interface, " ~
"InputProtocol, OutputProtocol), TClientBase!Interface {\n";
code ~= q{
this(IProt iprot, OProt oprot) {
super(iprot, oprot);
}
static if (is(IProt == OProt)) {
this(IProt prot) {
super(prot);
}
}
// DMD @@BUG@@: If these are not present in this class (would be)
// inherited anyway, »not implemented« errors are raised.
override IProt inputProtocol() @property {
return super.inputProtocol;
}
override OProt outputProtocol() @property {
return super.outputProtocol;
}
};
} else {
string code = "class TClient : TClientBase!Interface {";
code ~= q{
alias InputProtocol IProt;
static if (isTProtocol!OutputProtocol) {
alias OutputProtocol OProt;
} else {
static assert(is(OutputProtocol == void));
alias InputProtocol OProt;
}
this(IProt iprot, OProt oprot) {
iprot_ = iprot;
oprot_ = oprot;
}
static if (is(IProt == OProt)) {
this(IProt prot) {
this(prot, prot);
}
}
IProt inputProtocol() @property {
return iprot_;
}
OProt outputProtocol() @property {
return oprot_;
}
protected IProt iprot_;
protected OProt oprot_;
protected int seqid_;
};
}
foreach (methodName; __traits(derivedMembers, Interface)) {
static if (isSomeFunction!(mixin("Interface." ~ methodName))) {
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
enum meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
// Generate the code for sending.
string[] paramList;
string paramAssignCode;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
// Use the param name speficied in the meta information if any
// just cosmetics in this case.
string paramName;
if (methodMetaFound && i < methodMeta.params.length) {
paramName = methodMeta.params[i].name;
} else {
paramName = "param" ~ to!string(i + 1);
}
immutable storage = ParameterStorageClassTuple!(
mixin("Interface." ~ methodName))[i];
paramList ~= ((storage & ParameterStorageClass.ref_) ? "ref " : "") ~
"ParameterTypeTuple!(Interface." ~ methodName ~ ")[" ~
to!string(i) ~ "] " ~ paramName;
paramAssignCode ~= "args." ~ paramName ~ " = &" ~ paramName ~ ";\n";
}
code ~= "ReturnType!(Interface." ~ methodName ~ ") " ~ methodName ~
"(" ~ ctfeJoin(paramList) ~ ") {\n";
code ~= "immutable methodName = `" ~ methodName ~ "`;\n";
immutable paramStructType =
"TPargsStruct!(Interface, `" ~ methodName ~ "`)";
code ~= paramStructType ~ " args = " ~ paramStructType ~ "();\n";
code ~= paramAssignCode;
code ~= "oprot_.writeMessageBegin(TMessage(`" ~ methodName ~ "`, ";
code ~= ((methodMetaFound && methodMeta.type == TMethodType.ONEWAY)
? "TMessageType.ONEWAY" : "TMessageType.CALL");
code ~= ", ++seqid_));\n";
code ~= "args.write(oprot_);\n";
code ~= "oprot_.writeMessageEnd();\n";
code ~= "oprot_.transport.flush();\n";
// If this is not a oneway method, generate the receiving code.
if (!methodMetaFound || methodMeta.type != TMethodType.ONEWAY) {
code ~= "TPresultStruct!(Interface, `" ~ methodName ~ "`) result;\n";
if (!is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= "ReturnType!(Interface." ~ methodName ~ ") _return;\n";
code ~= "result.success = &_return;\n";
}
// TODO: The C++ implementation checks for matching method name here,
// should we do as well?
code ~= q{
auto msg = iprot_.readMessageBegin();
scope (exit) {
iprot_.readMessageEnd();
iprot_.transport.readEnd();
}
if (msg.type == TMessageType.EXCEPTION) {
auto x = new TApplicationException(null);
x.read(iprot_);
iprot_.transport.readEnd();
throw x;
}
if (msg.type != TMessageType.REPLY) {
skip(iprot_, TType.STRUCT);
iprot_.transport.readEnd();
}
if (msg.seqid != seqid_) {
throw new TApplicationException(
methodName ~ " failed: Out of sequence response.",
TApplicationException.Type.BAD_SEQUENCE_ID
);
}
result.read(iprot_);
};
if (methodMetaFound) {
foreach (e; methodMeta.exceptions) {
code ~= "if (result.isSet!`" ~ e.name ~ "`) throw result." ~
e.name ~ ";\n";
}
}
if (!is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= q{
if (result.isSet!`success`) return _return;
throw new TApplicationException(
methodName ~ " failed: Unknown result.",
TApplicationException.Type.MISSING_RESULT
);
};
}
}
code ~= "}\n";
}
}
code ~= "}\n";
return code;
}());
}
/**
* TClient construction helper to avoid having to explicitly specify
* the protocol types, i.e. to allow the constructor being called using IFTI
* (see $(DMDBUG 6082, D Bugzilla enhancement requet 6082)).
*/
TClient!(Interface, Prot) tClient(Interface, Prot)(Prot prot) if (
isService!Interface && isTProtocol!Prot
) {
return new TClient!(Interface, Prot)(prot);
}
/// Ditto
TClient!(Interface, IProt, Oprot) tClient(Interface, IProt, OProt)
(IProt iprot, OProt oprot) if (
isService!Interface && isTProtocol!IProt && isTProtocol!OProt
) {
return new TClient!(Interface, IProt, OProt)(iprot, oprot);
}
/**
* Represents the arguments of a Thrift method call, as pointers to the (const)
* parameter type to avoid copying.
*
* There should usually be no reason to use this struct directly without the
* help of TClient, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider this example:
* ---
* interface Foo {
* int bar(string a, bool b);
*
* enum methodMeta = [
* TMethodMeta("bar", [TParamMeta("a", 1), TParamMeta("b", 2)])
* ];
* }
*
* alias TPargsStruct!(Foo, "bar") FooBarPargs;
* ---
*
* The definition of FooBarPargs is equivalent to (ignoring the necessary
* metadata to assign the field IDs):
* ---
* struct FooBarPargs {
* const(string)* a;
* const(bool)* b;
*
* void write(Protocol)(Protocol proto) const if (isTProtocol!Protocol);
* }
* ---
*/
template TPargsStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
string memberCode;
string[] fieldMetaCodes;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
// If we have no meta information, just use param1, param2, etc. as
// field names, it shouldn't really matter anyway. 1-based »indexing«
// is used to match the common scheme in the Thrift world.
string memberId;
string memberName;
if (methodMetaFound && i < methodMeta.params.length) {
memberId = to!string(methodMeta.params[i].id);
memberName = methodMeta.params[i].name;
} else {
memberId = to!string(i + 1);
memberName = "param" ~ to!string(i + 1);
}
// Workaround for DMD @@BUG@@ 6056: make an intermediary alias for the
// parameter type, and declare the member using const(memberNameType)*.
memberCode ~= "alias ParameterTypeTuple!(Interface." ~ methodName ~
")[" ~ to!string(i) ~ "] " ~ memberName ~ "Type;\n";
memberCode ~= "const(" ~ memberName ~ "Type)* " ~ memberName ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ memberName ~ "`, " ~ memberId ~
", TReq.OPT_IN_REQ_OUT)";
}
string code = "struct TPargsStruct {\n";
code ~= memberCode;
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.base.TPargsStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
code ~= "void write(P)(P proto) const if (isTProtocol!P) {\n";
code ~= "writeStruct!(typeof(this), P, [" ~ ctfeJoin(fieldMetaCodes) ~
"], true)(this, proto);\n";
code ~= "}\n";
code ~= "}\n";
return code;
}());
}
/**
* Represents the result of a Thrift method call, using a pointer to the return
* value to avoid copying.
*
* There should usually be no reason to use this struct directly without the
* help of TClient, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider this example:
* ---
* interface Foo {
* int bar(string a);
*
* alias .FooException FooException;
*
* enum methodMeta = [
* TMethodMeta("bar",
* [TParamMeta("a", 1)],
* [TExceptionMeta("fooe", 1, "FooException")]
* )
* ];
* }
* alias TPresultStruct!(Foo, "bar") FooBarPresult;
* ---
*
* The definition of FooBarPresult is equivalent to (ignoring the necessary
* metadata to assign the field IDs):
* ---
* struct FooBarPresult {
* int* success;
* Foo.FooException fooe;
*
* struct IsSetFlags {
* bool success;
* }
* IsSetFlags isSetFlags;
*
* bool isSet(string fieldName)() const @property;
* void read(Protocol)(Protocol proto) if (isTProtocol!Protocol);
* }
* ---
*/
template TPresultStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
string code = "struct TPresultStruct {\n";
string[] fieldMetaCodes;
alias ReturnType!(mixin("Interface." ~ methodName)) ResultType;
static if (!is(ResultType == void)) {
code ~= q{
ReturnType!(mixin("Interface." ~ methodName))* success;
};
fieldMetaCodes ~= "TFieldMeta(`success`, 0, TReq.OPTIONAL)";
static if (!isNullable!ResultType) {
code ~= q{
struct IsSetFlags {
bool success;
}
IsSetFlags isSetFlags;
};
fieldMetaCodes ~= "TFieldMeta(`isSetFlags`, 0, TReq.IGNORE)";
}
}
bool methodMetaFound;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
foreach (e; meta.front.exceptions) {
code ~= "Interface." ~ e.type ~ " " ~ e.name ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ e.name ~ "`, " ~ to!string(e.id) ~
", TReq.OPTIONAL)";
}
methodMetaFound = true;
}
}
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.base.TPresultStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
code ~= q{
bool isSet(string fieldName)() const @property if (
is(MemberType!(typeof(this), fieldName))
) {
static if (fieldName == "success") {
static if (isNullable!(typeof(*success))) {
return *success !is null;
} else {
return isSetFlags.success;
}
} else {
// We are dealing with an exception member, which, being a nullable
// type (exceptions are always classes), has no isSet flag.
return __traits(getMember, this, fieldName) !is null;
}
}
};
code ~= "void read(P)(P proto) if (isTProtocol!P) {\n";
code ~= "readStruct!(typeof(this), P, [" ~ ctfeJoin(fieldMetaCodes) ~
"], true)(this, proto);\n";
code ~= "}\n";
code ~= "}\n";
return code;
}());
}

View file

@ -0,0 +1,262 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.client_pool;
import core.time : dur, Duration, TickDuration;
import std.traits : ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.codegen.base;
import thrift.codegen.client;
import thrift.internal.codegen;
import thrift.internal.resource_pool;
/**
* Manages a pool of TClients for the given interface, forwarding RPC calls to
* members of the pool.
*
* If a request fails, another client from the pool is tried, and optionally,
* a client is disabled for a configurable amount of time if it fails too
* often. If all clients fail (and keepTrying is false), a
* TCompoundOperationException is thrown, containing all the collected RPC
* exceptions.
*/
class TClientPool(Interface) if (isService!Interface) : Interface {
/// Shorthand for TClientBase!Interface, the client type this instance
/// operates on.
alias TClientBase!Interface Client;
/**
* Creates a new instance and adds the given clients to the pool.
*/
this(Client[] clients) {
pool_ = new TResourcePool!Client(clients);
rpcFaultFilter = (Exception e) {
import thrift.protocol.base;
import thrift.transport.base;
return (
(cast(TTransportException)e !is null) ||
(cast(TApplicationException)e !is null)
);
};
}
/**
* Executes an operation on the first currently active client.
*
* If the operation fails (throws an exception for which rpcFaultFilter is
* true), the failure is recorded and the next client in the pool is tried.
*
* Throws: Any non-rpc exception that occurs, a TCompoundOperationException
* if all clients failed with an rpc exception (if keepTrying is false).
*
* Example:
* ---
* interface Foo { string bar(); }
* auto poolClient = tClientPool([tClient!Foo(someProtocol)]);
* auto result = poolClient.execute((c){ return c.bar(); });
* ---
*/
ResultType execute(ResultType)(scope ResultType delegate(Client) work) {
return executeOnPool!Client(work);
}
/**
* Adds a client to the pool.
*/
void addClient(Client client) {
pool_.add(client);
}
/**
* Removes a client from the pool.
*
* Returns: Whether the client was found in the pool.
*/
bool removeClient(Client client) {
return pool_.remove(client);
}
mixin(poolForwardCode!Interface());
/// Whether to open the underlying transports of a client before trying to
/// execute a method if they are not open. This is usually desirable
/// because it allows e.g. to automatically reconnect to a remote server
/// if the network connection is dropped.
///
/// Defaults to true.
bool reopenTransports = true;
/// Called to determine whether an exception comes from a client from the
/// pool not working properly, or if it an exception thrown at the
/// application level.
///
/// If the delegate returns true, the server/connection is considered to be
/// at fault, if it returns false, the exception is just passed on to the
/// caller.
///
/// By default, returns true for instances of TTransportException and
/// TApplicationException, false otherwise.
bool delegate(Exception) rpcFaultFilter;
/**
* Whether to keep trying to find a working client if all have failed in a
* row.
*
* Defaults to false.
*/
bool keepTrying() const @property {
return pool_.cycle;
}
/// Ditto
void keepTrying(bool value) @property {
pool_.cycle = value;
}
/**
* Whether to use a random permutation of the client pool on every call to
* execute(). This can be used e.g. as a simple form of load balancing.
*
* Defaults to true.
*/
bool permuteClients() const @property {
return pool_.permute;
}
/// Ditto
void permuteClients(bool value) @property {
pool_.permute = value;
}
/**
* The number of consecutive faults after which a client is disabled until
* faultDisableDuration has passed. 0 to never disable clients.
*
* Defaults to 0.
*/
ushort faultDisableCount() @property {
return pool_.faultDisableCount;
}
/// Ditto
void faultDisableCount(ushort value) @property {
pool_.faultDisableCount = value;
}
/**
* The duration for which a client is no longer considered after it has
* failed too often.
*
* Defaults to one second.
*/
Duration faultDisableDuration() @property {
return pool_.faultDisableDuration;
}
/// Ditto
void faultDisableDuration(Duration value) @property {
pool_.faultDisableDuration = value;
}
protected:
ResultType executeOnPool(ResultType)(scope ResultType delegate(Client) work) {
auto clients = pool_[];
if (clients.empty) {
throw new TException("No clients available to try.");
}
while (true) {
Exception[] rpcExceptions;
while (!clients.empty) {
auto c = clients.front;
clients.popFront;
try {
scope (success) {
pool_.recordSuccess(c);
}
if (reopenTransports) {
c.inputProtocol.transport.open();
c.outputProtocol.transport.open();
}
return work(c);
} catch (Exception e) {
if (rpcFaultFilter && rpcFaultFilter(e)) {
pool_.recordFault(c);
rpcExceptions ~= e;
} else {
// We are dealing with a normal exception thrown by the
// server-side method, just pass it on. As far as we are
// concerned, the method call succeeded.
pool_.recordSuccess(c);
throw e;
}
}
}
// If we get here, no client succeeded during the current iteration.
Duration waitTime;
Client dummy;
if (clients.willBecomeNonempty(dummy, waitTime)) {
if (waitTime > dur!"hnsecs"(0)) {
import core.thread;
Thread.sleep(waitTime);
}
} else {
throw new TCompoundOperationException("All clients failed.",
rpcExceptions);
}
}
}
private:
TResourcePool!Client pool_;
}
private {
// Cannot use an anonymous delegate literal for this because they aren't
// allowed in class scope.
string poolForwardCode(Interface)() {
string code = "";
foreach (methodName; AllMemberMethodNames!Interface) {
enum qn = "Interface." ~ methodName;
code ~= "ReturnType!(" ~ qn ~ ") " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ ") args) {\n";
code ~= "return executeOnPool((Client c){ return c." ~
methodName ~ "(args); });\n";
code ~= "}\n";
}
return code;
}
}
/**
* TClientPool construction helper to avoid having to explicitly specify
* the interface type, i.e. to allow the constructor being called using IFTI
* (see $(DMDBUG 6082, D Bugzilla enhancement requet 6082)).
*/
TClientPool!Interface tClientPool(Interface)(
TClientBase!Interface[] clients
) if (isService!Interface) {
return new typeof(return)(clients);
}

View file

@ -0,0 +1,770 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Contains <b>experimental</b> functionality for generating Thrift IDL files
* (.thrift) from existing D data structures, i.e. the reverse of what the
* Thrift compiler does.
*/
module thrift.codegen.idlgen;
import std.algorithm : find;
import std.array : empty, front;
import std.conv : to;
import std.traits : EnumMembers, isSomeFunction, OriginalType,
ParameterTypeTuple, ReturnType;
import std.typetuple : allSatisfy, staticIndexOf, staticMap, NoDuplicates,
TypeTuple;
import thrift.base;
import thrift.codegen.base;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.util.hashset;
/**
* True if the passed type is a Thrift entity (struct, exception, enum,
* service).
*/
alias Any!(isStruct, isException, isEnum, isService) isThriftEntity;
/**
* Returns an IDL string describing the passed »root« entities and all types
* they depend on.
*/
template idlString(Roots...) if (allSatisfy!(isThriftEntity, Roots)) {
enum idlString = idlStringImpl!Roots.result;
}
private {
template idlStringImpl(Roots...) if (allSatisfy!(isThriftEntity, Roots)) {
alias ForAllWithList!(
ConfinedTuple!(StaticFilter!(isService, Roots)),
AddBaseServices
) Services;
alias TypeTuple!(
StaticFilter!(isEnum, Roots),
ForAllWithList!(
ConfinedTuple!(
StaticFilter!(Any!(isException, isStruct), Roots),
staticMap!(CompositeTypeDeps, staticMap!(ServiceTypeDeps, Services))
),
AddStructWithDeps
)
) Types;
enum result = ctfeJoin(
[
staticMap!(
enumIdlString,
StaticFilter!(isEnum, Types)
),
staticMap!(
structIdlString,
StaticFilter!(Any!(isStruct, isException), Types)
),
staticMap!(
serviceIdlString,
Services
)
],
"\n"
);
}
template ServiceTypeDeps(T) if (isService!T) {
alias staticMap!(
PApply!(MethodTypeDeps, T),
FilterMethodNames!(T, __traits(derivedMembers, T))
) ServiceTypeDeps;
}
template MethodTypeDeps(T, string name) if (
isService!T && isSomeFunction!(MemberType!(T, name))
) {
alias TypeTuple!(
ReturnType!(MemberType!(T, name)),
ParameterTypeTuple!(MemberType!(T, name)),
ExceptionTypes!(T, name)
) MethodTypeDeps;
}
template ExceptionTypes(T, string name) if (
isService!T && isSomeFunction!(MemberType!(T, name))
) {
mixin({
enum meta = find!`a.name == b`(getMethodMeta!T, name);
if (meta.empty) return "alias TypeTuple!() ExceptionTypes;";
string result = "alias TypeTuple!(";
foreach (i, e; meta.front.exceptions) {
if (i > 0) result ~= ", ";
result ~= "mixin(`T." ~ e.type ~ "`)";
}
result ~= ") ExceptionTypes;";
return result;
}());
}
template AddBaseServices(T, List...) {
static if (staticIndexOf!(T, List) == -1) {
alias NoDuplicates!(BaseServices!T, List) AddBaseServices;
} else {
alias List AddStructWithDeps;
}
}
unittest {
interface A {}
interface B : A {}
interface C : B {}
interface D : A {}
static assert(is(AddBaseServices!(C) == TypeTuple!(A, B, C)));
static assert(is(ForAllWithList!(ConfinedTuple!(C, D), AddBaseServices) ==
TypeTuple!(A, D, B, C)));
}
template BaseServices(T, Rest...) if (isService!T) {
static if (isDerivedService!T) {
alias BaseServices!(BaseService!T, T, Rest) BaseServices;
} else {
alias TypeTuple!(T, Rest) BaseServices;
}
}
template AddStructWithDeps(T, List...) {
static if (staticIndexOf!(T, List) == -1) {
// T is not already in the List, so add T and the types it depends on in
// the front. Because with the Thrift compiler types can only depend on
// other types that have already been defined, we collect all the
// dependencies, prepend them to the list, and then prune the duplicates
// (keeping the first occurrences). If this requirement should ever be
// dropped from Thrift, this could be easily adapted to handle circular
// dependencies by passing TypeTuple!(T, List) to ForAllWithList instead
// of appending List afterwards, and removing the now unnecessary
// NoDuplicates.
alias NoDuplicates!(
ForAllWithList!(
ConfinedTuple!(
staticMap!(
CompositeTypeDeps,
staticMap!(
PApply!(MemberType, T),
FieldNames!T
)
)
),
.AddStructWithDeps,
T
),
List
) AddStructWithDeps;
} else {
alias List AddStructWithDeps;
}
}
version (unittest) {
struct A {}
struct B {
A a;
int b;
A c;
string d;
}
struct C {
B b;
A a;
}
static assert(is(AddStructWithDeps!C == TypeTuple!(A, B, C)));
struct D {
C c;
mixin TStructHelpers!([TFieldMeta("c", 0, TReq.IGNORE)]);
}
static assert(is(AddStructWithDeps!D == TypeTuple!(D)));
}
version (unittest) {
// Circles in the type dependency graph are not allowed in Thrift, but make
// sure we fail in a sane way instead of crashing the compiler.
struct Rec1 {
Rec2[] other;
}
struct Rec2 {
Rec1[] other;
}
static assert(!__traits(compiles, AddStructWithDeps!Rec1));
}
/*
* Returns the non-primitive types T directly depends on.
*
* For example, CompositeTypeDeps!int would yield an empty type tuple,
* CompositeTypeDeps!SomeStruct would give SomeStruct, and
* CompositeTypeDeps!(A[B]) both CompositeTypeDeps!A and CompositeTypeDeps!B.
*/
template CompositeTypeDeps(T) {
static if (is(FullyUnqual!T == bool) || is(FullyUnqual!T == byte) ||
is(FullyUnqual!T == short) || is(FullyUnqual!T == int) ||
is(FullyUnqual!T == long) || is(FullyUnqual!T : string) ||
is(FullyUnqual!T == double) || is(FullyUnqual!T == void)
) {
alias TypeTuple!() CompositeTypeDeps;
} else static if (is(FullyUnqual!T _ : U[], U)) {
alias CompositeTypeDeps!U CompositeTypeDeps;
} else static if (is(FullyUnqual!T _ : HashSet!E, E)) {
alias CompositeTypeDeps!E CompositeTypeDeps;
} else static if (is(FullyUnqual!T _ : V[K], K, V)) {
alias TypeTuple!(CompositeTypeDeps!K, CompositeTypeDeps!V) CompositeTypeDeps;
} else static if (is(FullyUnqual!T == enum) || is(FullyUnqual!T == struct) ||
is(FullyUnqual!T : TException)
) {
alias TypeTuple!(FullyUnqual!T) CompositeTypeDeps;
} else {
static assert(false, "Cannot represent type in Thrift: " ~ T.stringof);
}
}
}
/**
* Returns an IDL string describing the passed service. IDL code for any type
* dependcies is not included.
*/
template serviceIdlString(T) if (isService!T) {
enum serviceIdlString = {
string result = "service " ~ T.stringof;
static if (isDerivedService!T) {
result ~= " extends " ~ BaseService!T.stringof;
}
result ~= " {\n";
foreach (methodName; FilterMethodNames!(T, __traits(derivedMembers, T))) {
result ~= " ";
enum meta = find!`a.name == b`(T.methodMeta, methodName);
static if (!meta.empty && meta.front.type == TMethodType.ONEWAY) {
result ~= "oneway ";
}
alias ReturnType!(MemberType!(T, methodName)) RT;
static if (is(RT == void)) {
// We special-case this here instead of adding void to dToIdlType to
// avoid accepting things like void[].
result ~= "void ";
} else {
result ~= dToIdlType!RT ~ " ";
}
result ~= methodName ~ "(";
short lastId;
foreach (i, ParamType; ParameterTypeTuple!(MemberType!(T, methodName))) {
static if (!meta.empty && i < meta.front.params.length) {
enum havePM = true;
} else {
enum havePM = false;
}
short id;
static if (havePM) {
id = meta.front.params[i].id;
} else {
id = --lastId;
}
string paramName;
static if (havePM) {
paramName = meta.front.params[i].name;
} else {
paramName = "param" ~ to!string(i + 1);
}
result ~= to!string(id) ~ ": " ~ dToIdlType!ParamType ~ " " ~ paramName;
static if (havePM && !meta.front.params[i].defaultValue.empty) {
result ~= " = " ~ dToIdlConst(mixin(meta.front.params[i].defaultValue));
} else {
// Unfortunately, getting the default value for parameters from a
// function alias isn't possible we can't transfer the default
// value to the IDL e.g. for interface Foo { void foo(int a = 5); }
// without the user explicitly declaring it in metadata.
}
result ~= ", ";
}
result ~= ")";
static if (!meta.empty && !meta.front.exceptions.empty) {
result ~= " throws (";
foreach (e; meta.front.exceptions) {
result ~= to!string(e.id) ~ ": " ~ e.type ~ " " ~ e.name ~ ", ";
}
result ~= ")";
}
result ~= ",\n";
}
result ~= "}\n";
return result;
}();
}
/**
* Returns an IDL string describing the passed enum. IDL code for any type
* dependcies is not included.
*/
template enumIdlString(T) if (isEnum!T) {
enum enumIdlString = {
static assert(is(OriginalType!T : long),
"Can only have integer enums in Thrift (not " ~ OriginalType!T.stringof ~
", for " ~ T.stringof ~ ").");
string result = "enum " ~ T.stringof ~ " {\n";
foreach (name; __traits(derivedMembers, T)) {
result ~= " " ~ name ~ " = " ~ dToIdlConst(GetMember!(T, name)) ~ ",\n";
}
result ~= "}\n";
return result;
}();
}
/**
* Returns an IDL string describing the passed struct. IDL code for any type
* dependcies is not included.
*/
template structIdlString(T) if (isStruct!T || isException!T) {
enum structIdlString = {
mixin({
string code = "";
foreach (field; getFieldMeta!T) {
code ~= "static assert(is(MemberType!(T, `" ~ field.name ~ "`)));\n";
}
return code;
}());
string result;
static if (isException!T) {
result = "exception ";
} else {
result = "struct ";
}
result ~= T.stringof ~ " {\n";
// The last automatically assigned id fields with no meta information
// are assigned (in lexical order) descending negative ids, starting with
// -1, just like the Thrift compiler does.
short lastId;
foreach (name; FieldNames!T) {
enum meta = find!`a.name == b`(getFieldMeta!T, name);
static if (meta.empty || meta.front.req != TReq.IGNORE) {
short id;
static if (meta.empty) {
id = --lastId;
} else {
id = meta.front.id;
}
result ~= " " ~ to!string(id) ~ ":";
static if (!meta.empty) {
result ~= dToIdlReq(meta.front.req);
}
result ~= " " ~ dToIdlType!(MemberType!(T, name)) ~ " " ~ name;
static if (!meta.empty && !meta.front.defaultValue.empty) {
result ~= " = " ~ dToIdlConst(mixin(meta.front.defaultValue));
} else static if (__traits(compiles, fieldInitA!(T, name))) {
static if (is(typeof(fieldInitA!(T, name))) &&
!is(typeof(fieldInitA!(T, name)) == void)
) {
result ~= " = " ~ dToIdlConst(fieldInitA!(T, name));
}
} else static if (is(typeof(fieldInitB!(T, name))) &&
!is(typeof(fieldInitB!(T, name)) == void)
) {
result ~= " = " ~ dToIdlConst(fieldInitB!(T, name));
}
result ~= ",\n";
}
}
result ~= "}\n";
return result;
}();
}
private {
// This very convoluted way of doing things was chosen because putting the
// static if directly into structIdlString caused »not evaluatable at compile
// time« errors to slip through even though typeof() was used, resp. the
// condition to be true even though the value couldn't actually be read at
// compile time due to a @@BUG@@ in DMD 2.055.
// The extra »compiled« field in fieldInitA is needed because we must not try
// to use != if !is compiled as well (but was false), e.g. for floating point
// types.
template fieldInitA(T, string name) {
static if (mixin("T.init." ~ name) !is MemberType!(T, name).init) {
enum fieldInitA = mixin("T.init." ~ name);
}
}
template fieldInitB(T, string name) {
static if (mixin("T.init." ~ name) != MemberType!(T, name).init) {
enum fieldInitB = mixin("T.init." ~ name);
}
}
template dToIdlType(T) {
static if (is(FullyUnqual!T == bool)) {
enum dToIdlType = "bool";
} else static if (is(FullyUnqual!T == byte)) {
enum dToIdlType = "byte";
} else static if (is(FullyUnqual!T == double)) {
enum dToIdlType = "double";
} else static if (is(FullyUnqual!T == short)) {
enum dToIdlType = "i16";
} else static if (is(FullyUnqual!T == int)) {
enum dToIdlType = "i32";
} else static if (is(FullyUnqual!T == long)) {
enum dToIdlType = "i64";
} else static if (is(FullyUnqual!T : string)) {
enum dToIdlType = "string";
} else static if (is(FullyUnqual!T _ : U[], U)) {
enum dToIdlType = "list<" ~ dToIdlType!U ~ ">";
} else static if (is(FullyUnqual!T _ : V[K], K, V)) {
enum dToIdlType = "map<" ~ dToIdlType!K ~ ", " ~ dToIdlType!V ~ ">";
} else static if (is(FullyUnqual!T _ : HashSet!E, E)) {
enum dToIdlType = "set<" ~ dToIdlType!E ~ ">";
} else static if (is(FullyUnqual!T == struct) || is(FullyUnqual!T == enum) ||
is(FullyUnqual!T : TException)
) {
enum dToIdlType = FullyUnqual!(T).stringof;
} else {
static assert(false, "Cannot represent type in Thrift: " ~ T.stringof);
}
}
string dToIdlReq(TReq req) {
switch (req) {
case TReq.REQUIRED: return " required";
case TReq.OPTIONAL: return " optional";
default: return "";
}
}
string dToIdlConst(T)(T value) {
static if (is(FullyUnqual!T == bool)) {
return value ? "1" : "0";
} else static if (is(FullyUnqual!T == byte) ||
is(FullyUnqual!T == short) || is(FullyUnqual!T == int) ||
is(FullyUnqual!T == long)
) {
return to!string(value);
} else static if (is(FullyUnqual!T : string)) {
return `"` ~ to!string(value) ~ `"`;
} else static if (is(FullyUnqual!T == double)) {
return ctfeToString(value);
} else static if (is(FullyUnqual!T _ : U[], U) ||
is(FullyUnqual!T _ : HashSet!E, E)
) {
string result = "[";
foreach (e; value) {
result ~= dToIdlConst(e) ~ ", ";
}
result ~= "]";
return result;
} else static if (is(FullyUnqual!T _ : V[K], K, V)) {
string result = "{";
foreach (key, val; value) {
result ~= dToIdlConst(key) ~ ": " ~ dToIdlConst(val) ~ ", ";
}
result ~= "}";
return result;
} else static if (is(FullyUnqual!T == enum)) {
import std.conv;
import std.traits;
return to!string(cast(OriginalType!T)value);
} else static if (is(FullyUnqual!T == struct) ||
is(FullyUnqual!T : TException)
) {
string result = "{";
foreach (name; __traits(derivedMembers, T)) {
static if (memberReq!(T, name) != TReq.IGNORE) {
result ~= name ~ ": " ~ dToIdlConst(mixin("value." ~ name)) ~ ", ";
}
}
result ~= "}";
return result;
} else {
static assert(false, "Cannot represent type in Thrift: " ~ T.stringof);
}
}
}
version (unittest) {
enum Foo {
a = 1,
b = 10,
c = 5
}
static assert(enumIdlString!Foo ==
`enum Foo {
a = 1,
b = 10,
c = 5,
}
`);
}
version (unittest) {
struct WithoutMeta {
string a;
int b;
}
struct WithDefaults {
string a = "asdf";
double b = 3.1415;
WithoutMeta c;
mixin TStructHelpers!([
TFieldMeta("c", 1, TReq.init, `WithoutMeta("foo", 3)`)
]);
}
// These are from DebugProtoTest.thrift.
struct OneOfEach {
bool im_true;
bool im_false;
byte a_bite;
short integer16;
int integer32;
long integer64;
double double_precision;
string some_characters;
string zomg_unicode;
bool what_who;
string base64;
byte[] byte_list;
short[] i16_list;
long[] i64_list;
mixin TStructHelpers!([
TFieldMeta(`im_true`, 1),
TFieldMeta(`im_false`, 2),
TFieldMeta(`a_bite`, 3, TReq.OPT_IN_REQ_OUT, q{cast(byte)127}),
TFieldMeta(`integer16`, 4, TReq.OPT_IN_REQ_OUT, q{cast(short)32767}),
TFieldMeta(`integer32`, 5),
TFieldMeta(`integer64`, 6, TReq.OPT_IN_REQ_OUT, q{10000000000L}),
TFieldMeta(`double_precision`, 7),
TFieldMeta(`some_characters`, 8),
TFieldMeta(`zomg_unicode`, 9),
TFieldMeta(`what_who`, 10),
TFieldMeta(`base64`, 11),
TFieldMeta(`byte_list`, 12, TReq.OPT_IN_REQ_OUT, q{{
byte[] v;
v ~= cast(byte)1;
v ~= cast(byte)2;
v ~= cast(byte)3;
return v;
}()}),
TFieldMeta(`i16_list`, 13, TReq.OPT_IN_REQ_OUT, q{{
short[] v;
v ~= cast(short)1;
v ~= cast(short)2;
v ~= cast(short)3;
return v;
}()}),
TFieldMeta(`i64_list`, 14, TReq.OPT_IN_REQ_OUT, q{{
long[] v;
v ~= 1L;
v ~= 2L;
v ~= 3L;
return v;
}()})
]);
}
struct Bonk {
int type;
string message;
mixin TStructHelpers!([
TFieldMeta(`type`, 1),
TFieldMeta(`message`, 2)
]);
}
struct HolyMoley {
OneOfEach[] big;
HashSet!(string[]) contain;
Bonk[][string] bonks;
mixin TStructHelpers!([
TFieldMeta(`big`, 1),
TFieldMeta(`contain`, 2),
TFieldMeta(`bonks`, 3)
]);
}
static assert(structIdlString!WithoutMeta ==
`struct WithoutMeta {
-1: string a,
-2: i32 b,
}
`);
import std.algorithm;
static assert(structIdlString!WithDefaults.startsWith(
`struct WithDefaults {
-1: string a = "asdf",
-2: double b = 3.141`));
static assert(structIdlString!WithDefaults.endsWith(
`1: WithoutMeta c = {a: "foo", b: 3, },
}
`));
static assert(structIdlString!OneOfEach ==
`struct OneOfEach {
1: bool im_true,
2: bool im_false,
3: byte a_bite = 127,
4: i16 integer16 = 32767,
5: i32 integer32,
6: i64 integer64 = 10000000000,
7: double double_precision,
8: string some_characters,
9: string zomg_unicode,
10: bool what_who,
11: string base64,
12: list<byte> byte_list = [1, 2, 3, ],
13: list<i16> i16_list = [1, 2, 3, ],
14: list<i64> i64_list = [1, 2, 3, ],
}
`);
static assert(structIdlString!Bonk ==
`struct Bonk {
1: i32 type,
2: string message,
}
`);
static assert(structIdlString!HolyMoley ==
`struct HolyMoley {
1: list<OneOfEach> big,
2: set<list<string>> contain,
3: map<string, list<Bonk>> bonks,
}
`);
}
version (unittest) {
class ExceptionWithAMap : TException {
string blah;
string[string] map_field;
mixin TStructHelpers!([
TFieldMeta(`blah`, 1),
TFieldMeta(`map_field`, 2)
]);
}
interface Srv {
void voidMethod();
int primitiveMethod();
OneOfEach structMethod();
void methodWithDefaultArgs(int something);
void onewayMethod();
void exceptionMethod();
alias .ExceptionWithAMap ExceptionWithAMap;
enum methodMeta = [
TMethodMeta(`methodWithDefaultArgs`,
[TParamMeta(`something`, 1, q{2})]
),
TMethodMeta(`onewayMethod`,
[],
[],
TMethodType.ONEWAY
),
TMethodMeta(`exceptionMethod`,
[],
[
TExceptionMeta("a", 1, "ExceptionWithAMap"),
TExceptionMeta("b", 2, "ExceptionWithAMap")
]
)
];
}
interface ChildSrv : Srv {
int childMethod(int arg);
}
static assert(idlString!ChildSrv ==
`exception ExceptionWithAMap {
1: string blah,
2: map<string, string> map_field,
}
struct OneOfEach {
1: bool im_true,
2: bool im_false,
3: byte a_bite = 127,
4: i16 integer16 = 32767,
5: i32 integer32,
6: i64 integer64 = 10000000000,
7: double double_precision,
8: string some_characters,
9: string zomg_unicode,
10: bool what_who,
11: string base64,
12: list<byte> byte_list = [1, 2, 3, ],
13: list<i16> i16_list = [1, 2, 3, ],
14: list<i64> i64_list = [1, 2, 3, ],
}
service Srv {
void voidMethod(),
i32 primitiveMethod(),
OneOfEach structMethod(),
void methodWithDefaultArgs(1: i32 something = 2, ),
oneway void onewayMethod(),
void exceptionMethod() throws (1: ExceptionWithAMap a, 2: ExceptionWithAMap b, ),
}
service ChildSrv extends Srv {
i32 childMethod(-1: i32 param1, ),
}
`);
}

View file

@ -0,0 +1,497 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.codegen.processor;
import std.algorithm : find;
import std.array : empty, front;
import std.conv : to;
import std.traits : ParameterTypeTuple, ReturnType, Unqual;
import std.typetuple : allSatisfy, TypeTuple;
import std.variant : Variant;
import thrift.base;
import thrift.codegen.base;
import thrift.internal.codegen;
import thrift.internal.ctfe;
import thrift.protocol.base;
import thrift.protocol.processor;
/**
* Service processor for Interface, which implements TProcessor by
* synchronously forwarding requests for the service methods to a handler
* implementing Interface.
*
* The generated class implements TProcessor and additionally allows a
* TProcessorEventHandler to be specified via the public eventHandler property.
* The constructor takes a single argument of type Interface, which is the
* handler to forward the requests to:
* ---
* this(Interface iface);
* TProcessorEventHandler eventHandler;
* ---
*
* If Interface is derived from another service BaseInterface, this class is
* also derived from TServiceProcessor!BaseInterface.
*
* The optional Protocols template tuple parameter can be used to specify
* one or more TProtocol implementations to specifically generate code for. If
* the actual types of the protocols passed to process() at runtime match one
* of the items from the list, the optimized code paths are taken, otherwise,
* a generic TProtocol version is used as fallback. For cases where the input
* and output protocols differ, TProtocolPair!(InputProtocol, OutputProtocol)
* can be used in the Protocols list:
* ---
* interface FooService { void foo(); }
* class FooImpl { override void foo {} }
*
* // Provides fast path if TBinaryProtocol!TBufferedTransport is used for
* // both input and output:
* alias TServiceProcessor!(FooService, TBinaryProtocol!TBufferedTransport)
* BinaryProcessor;
*
* auto proc = new BinaryProcessor(new FooImpl());
*
* // Low overhead.
* proc.process(tBinaryProtocol(tBufferTransport(someSocket)));
*
* // Not in the specialization list higher overhead.
* proc.process(tBinaryProtocol(tFramedTransport(someSocket)));
*
* // Same as above, but optimized for the Compact protocol backed by a
* // TPipedTransport for input and a TBufferedTransport for output.
* alias TServiceProcessor!(FooService, TProtocolPair!(
* TCompactProtocol!TPipedTransport, TCompactProtocol!TBufferedTransport)
* ) MixedProcessor;
* ---
*/
template TServiceProcessor(Interface, Protocols...) if (
isService!Interface && allSatisfy!(isTProtocolOrPair, Protocols)
) {
mixin({
static if (is(Interface BaseInterfaces == super) && BaseInterfaces.length > 0) {
static assert(BaseInterfaces.length == 1,
"Services cannot be derived from more than one parent.");
string code = "class TServiceProcessor : " ~
"TServiceProcessor!(BaseService!Interface) {\n";
code ~= "private Interface iface_;\n";
string constructorCode = "this(Interface iface) {\n";
constructorCode ~= "super(iface);\n";
constructorCode ~= "iface_ = iface;\n";
} else {
string code = "class TServiceProcessor : TProcessor {";
code ~= q{
override bool process(TProtocol iprot, TProtocol oprot,
Variant context = Variant()
) {
auto msg = iprot.readMessageBegin();
void writeException(TApplicationException e) {
oprot.writeMessageBegin(TMessage(msg.name, TMessageType.EXCEPTION,
msg.seqid));
e.write(oprot);
oprot.writeMessageEnd();
oprot.transport.writeEnd();
oprot.transport.flush();
}
if (msg.type != TMessageType.CALL && msg.type != TMessageType.ONEWAY) {
skip(iprot, TType.STRUCT);
iprot.readMessageEnd();
iprot.transport.readEnd();
writeException(new TApplicationException(
TApplicationException.Type.INVALID_MESSAGE_TYPE));
return false;
}
auto dg = msg.name in processMap_;
if (!dg) {
skip(iprot, TType.STRUCT);
iprot.readMessageEnd();
iprot.transport.readEnd();
writeException(new TApplicationException("Invalid method name: '" ~
msg.name ~ "'.", TApplicationException.Type.INVALID_MESSAGE_TYPE));
return false;
}
(*dg)(msg.seqid, iprot, oprot, context);
return true;
}
TProcessorEventHandler eventHandler;
alias void delegate(int, TProtocol, TProtocol, Variant) ProcessFunc;
protected ProcessFunc[string] processMap_;
private Interface iface_;
};
string constructorCode = "this(Interface iface) {\n";
constructorCode ~= "iface_ = iface;\n";
}
// Generate the handling code for each method, consisting of the dispatch
// function, registering it in the constructor, and the actual templated
// handler function.
foreach (methodName;
FilterMethodNames!(Interface, __traits(derivedMembers, Interface))
) {
// Register the processing function in the constructor.
immutable procFuncName = "process_" ~ methodName;
immutable dispatchFuncName = procFuncName ~ "_protocolDispatch";
constructorCode ~= "processMap_[`" ~ methodName ~ "`] = &" ~
dispatchFuncName ~ ";\n";
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
enum meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
// The dispatch function to call the specialized handler functions. We
// test the protocols if they can be converted to one of the passed
// protocol types, and if not, fall back to the generic TProtocol
// version of the processing function.
code ~= "void " ~ dispatchFuncName ~
"(int seqid, TProtocol iprot, TProtocol oprot, Variant context) {\n";
code ~= "foreach (Protocol; TypeTuple!(Protocols, TProtocol)) {\n";
code ~= q{
static if (is(Protocol _ : TProtocolPair!(I, O), I, O)) {
alias I IProt;
alias O OProt;
} else {
alias Protocol IProt;
alias Protocol OProt;
}
auto castedIProt = cast(IProt)iprot;
auto castedOProt = cast(OProt)oprot;
};
code ~= "if (castedIProt && castedOProt) {\n";
code ~= procFuncName ~
"!(IProt, OProt)(seqid, castedIProt, castedOProt, context);\n";
code ~= "return;\n";
code ~= "}\n";
code ~= "}\n";
code ~= "throw new TException(`Internal error: Null iprot/oprot " ~
"passed to processor protocol dispatch function.`);\n";
code ~= "}\n";
// The actual handler function, templated on the input and output
// protocol types.
code ~= "void " ~ procFuncName ~ "(IProt, OProt)(int seqid, IProt " ~
"iprot, OProt oprot, Variant connectionContext) " ~
"if (isTProtocol!IProt && isTProtocol!OProt) {\n";
code ~= "TArgsStruct!(Interface, `" ~ methodName ~ "`) args;\n";
// Store the (qualified) method name in a manifest constant to avoid
// having to litter the code below with lots of string manipulation.
code ~= "enum methodName = `" ~ methodName ~ "`;\n";
code ~= q{
enum qName = Interface.stringof ~ "." ~ methodName;
Variant callContext;
if (eventHandler) {
callContext = eventHandler.createContext(qName, connectionContext);
}
scope (exit) {
if (eventHandler) {
eventHandler.deleteContext(callContext, qName);
}
}
if (eventHandler) eventHandler.preRead(callContext, qName);
args.read(iprot);
iprot.readMessageEnd();
iprot.transport.readEnd();
if (eventHandler) eventHandler.postRead(callContext, qName);
};
code ~= "TResultStruct!(Interface, `" ~ methodName ~ "`) result;\n";
code ~= "try {\n";
// Generate the parameter list to pass to the called iface function.
string[] paramList;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
string paramName;
if (methodMetaFound && i < methodMeta.params.length) {
paramName = methodMeta.params[i].name;
} else {
paramName = "param" ~ to!string(i + 1);
}
paramList ~= "args." ~ paramName;
}
immutable call = "iface_." ~ methodName ~ "(" ~ ctfeJoin(paramList) ~ ")";
if (is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= call ~ ";\n";
} else {
code ~= "result.set!`success`(" ~ call ~ ");\n";
}
// If this is not a oneway method, generate the receiving code.
if (!methodMetaFound || methodMeta.type != TMethodType.ONEWAY) {
if (methodMetaFound) {
foreach (e; methodMeta.exceptions) {
code ~= "} catch (Interface." ~ e.type ~ " " ~ e.name ~ ") {\n";
code ~= "result.set!`" ~ e.name ~ "`(" ~ e.name ~ ");\n";
}
}
code ~= "}\n";
code ~= q{
catch (Exception e) {
if (eventHandler) {
eventHandler.handlerError(callContext, qName, e);
}
auto x = new TApplicationException(to!string(e));
oprot.writeMessageBegin(
TMessage(methodName, TMessageType.EXCEPTION, seqid));
x.write(oprot);
oprot.writeMessageEnd();
oprot.transport.writeEnd();
oprot.transport.flush();
return;
}
if (eventHandler) eventHandler.preWrite(callContext, qName);
oprot.writeMessageBegin(TMessage(methodName,
TMessageType.REPLY, seqid));
result.write(oprot);
oprot.writeMessageEnd();
oprot.transport.writeEnd();
oprot.transport.flush();
if (eventHandler) eventHandler.postWrite(callContext, qName);
};
} else {
// For oneway methods, we obviously cannot notify the client of any
// exceptions, just call the event handler if one is set.
code ~= "}\n";
code ~= q{
catch (Exception e) {
if (eventHandler) {
eventHandler.handlerError(callContext, qName, e);
}
return;
}
if (eventHandler) eventHandler.onewayComplete(callContext, qName);
};
}
code ~= "}\n";
}
code ~= constructorCode ~ "}\n";
code ~= "}\n";
return code;
}());
}
/**
* A struct representing the arguments of a Thrift method call.
*
* There should usually be no reason to use this directly without the help of
* TServiceProcessor, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider this example:
* ---
* interface Foo {
* int bar(string a, bool b);
*
* enum methodMeta = [
* TMethodMeta("bar", [TParamMeta("a", 1), TParamMeta("b", 2)])
* ];
* }
*
* alias TArgsStruct!(Foo, "bar") FooBarArgs;
* ---
*
* The definition of FooBarArgs is equivalent to:
* ---
* struct FooBarArgs {
* string a;
* bool b;
*
* mixin TStructHelpers!([TFieldMeta("a", 1, TReq.OPT_IN_REQ_OUT),
* TFieldMeta("b", 2, TReq.OPT_IN_REQ_OUT)]);
* }
* ---
*
* If the TVerboseCodegen version is defined, a warning message is issued at
* compilation if no TMethodMeta for Interface.methodName is found.
*/
template TArgsStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
bool methodMetaFound;
TMethodMeta methodMeta;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
methodMetaFound = true;
methodMeta = meta.front;
}
}
string memberCode;
string[] fieldMetaCodes;
foreach (i, _; ParameterTypeTuple!(mixin("Interface." ~ methodName))) {
// If we have no meta information, just use param1, param2, etc. as
// field names, it shouldn't really matter anyway. 1-based »indexing«
// is used to match the common scheme in the Thrift world.
string memberId;
string memberName;
if (methodMetaFound && i < methodMeta.params.length) {
memberId = to!string(methodMeta.params[i].id);
memberName = methodMeta.params[i].name;
} else {
memberId = to!string(i + 1);
memberName = "param" ~ to!string(i + 1);
}
// Unqual!() is needed to generate mutable fields for ref const()
// struct parameters.
memberCode ~= "Unqual!(ParameterTypeTuple!(Interface." ~ methodName ~
")[" ~ to!string(i) ~ "])" ~ memberName ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ memberName ~ "`, " ~ memberId ~
", TReq.OPT_IN_REQ_OUT)";
}
string code = "struct TArgsStruct {\n";
code ~= memberCode;
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.processor.TArgsStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
immutable fieldMetaCode =
fieldMetaCodes.empty ? "" : "[" ~ ctfeJoin(fieldMetaCodes) ~ "]";
code ~= "mixin TStructHelpers!(" ~ fieldMetaCode ~ ");\n";
code ~= "}\n";
return code;
}());
}
/**
* A struct representing the result of a Thrift method call.
*
* It contains a field called "success" for the return value of the function
* (with id 0), and additional fields for the exceptions declared for the
* method, if any.
*
* There should usually be no reason to use this directly without the help of
* TServiceProcessor, but it is documented publicly to help debugging in case
* of CTFE errors.
*
* Consider the following example:
* ---
* interface Foo {
* int bar(string a);
*
* alias .FooException FooException;
*
* enum methodMeta = [
* TMethodMeta("bar",
* [TParamMeta("a", 1)],
* [TExceptionMeta("fooe", 1, "FooException")]
* )
* ];
* }
* alias TResultStruct!(Foo, "bar") FooBarResult;
* ---
*
* The definition of FooBarResult is equivalent to:
* ---
* struct FooBarResult {
* int success;
* FooException fooe;
*
* mixin(TStructHelpers!([TFieldMeta("success", 0, TReq.OPTIONAL),
* TFieldMeta("fooe", 1, TReq.OPTIONAL)]));
* }
* ---
*
* If the TVerboseCodegen version is defined, a warning message is issued at
* compilation if no TMethodMeta for Interface.methodName is found.
*/
template TResultStruct(Interface, string methodName) {
static assert(is(typeof(mixin("Interface." ~ methodName))),
"Could not find method '" ~ methodName ~ "' in '" ~ Interface.stringof ~ "'.");
mixin({
string code = "struct TResultStruct {\n";
string[] fieldMetaCodes;
static if (!is(ReturnType!(mixin("Interface." ~ methodName)) == void)) {
code ~= "ReturnType!(Interface." ~ methodName ~ ") success;\n";
fieldMetaCodes ~= "TFieldMeta(`success`, 0, TReq.OPTIONAL)";
}
bool methodMetaFound;
static if (is(typeof(Interface.methodMeta) : TMethodMeta[])) {
auto meta = find!`a.name == b`(Interface.methodMeta, methodName);
if (!meta.empty) {
foreach (e; meta.front.exceptions) {
code ~= "Interface." ~ e.type ~ " " ~ e.name ~ ";\n";
fieldMetaCodes ~= "TFieldMeta(`" ~ e.name ~ "`, " ~ to!string(e.id) ~
", TReq.OPTIONAL)";
}
methodMetaFound = true;
}
}
version (TVerboseCodegen) {
if (!methodMetaFound &&
ParameterTypeTuple!(mixin("Interface." ~ methodName)).length > 0)
{
code ~= "pragma(msg, `[thrift.codegen.processor.TResultStruct] Warning: No " ~
"meta information for method '" ~ methodName ~ "' in service '" ~
Interface.stringof ~ "' found.`);\n";
}
}
immutable fieldMetaCode =
fieldMetaCodes.empty ? "" : "[" ~ ctfeJoin(fieldMetaCodes) ~ "]";
code ~= "mixin TStructHelpers!(" ~ fieldMetaCode ~ ");\n";
code ~= "}\n";
return code;
}());
}

View file

@ -0,0 +1,33 @@
Ddoc
<h2>Package overview</h2>
<dl>
<dt>$(D_CODE thrift.async)</dt>
<dd>Support infrastructure for handling client-side asynchronous operations using non-blocking I/O and coroutines.</dd>
<dt>$(D_CODE thrift.codegen)</dt>
<dd>
<p>Templates used for generating Thrift clients/processors from regular D struct and interface definitions.</p>
<p><strong>Note:</strong> Several artifacts in these modules have options for specifying the exact protocol types used. In this case, the amount of virtual calls can be greatly reduced and as a result, the code also can be optimized better. If performance is not a concern or the actual protocol type is not known at compile time, these parameters can just be left at their defaults.
</p>
</dd>
<dt>$(D_CODE thrift.internal)</dt>
<dd>Internal helper modules used by the Thrift library. This package is not part of the public API, and no stability guarantees are given whatsoever.</dd>
<dt>$(D_CODE thrift.protocol)</dt>
<dd>The Thrift protocol implemtations which specify how to pass messages over a TTransport.</dd>
<dt>$(D_CODE thrift.server)</dt>
<dd>Generic Thrift server implementations handling clients over a TTransport interface and forwarding requests to a TProcessor (which is in turn usually provided by thrift.codegen).</dd>
<dt>$(D_CODE thrift.transport)</dt>
<dd>The TTransport data source/sink interface used in the Thrift library and its imiplementations.</dd>
<dt>$(D_CODE thrift.util)</dt>
<dd>General-purpose utility modules not specific to Thrift, part of the public API.</dd>
</dl>
Macros:
TITLE = Thrift D Software Library

View file

@ -0,0 +1,55 @@
/**
* Contains a modified version of std.algorithm.remove that doesn't take an
* alias parameter to avoid DMD @@BUG6395@@.
*/
module thrift.internal.algorithm;
import std.algorithm : move;
import std.exception;
import std.functional;
import std.range;
import std.traits;
enum SwapStrategy
{
unstable,
semistable,
stable,
}
Range removeEqual(SwapStrategy s = SwapStrategy.stable, Range, E)(Range range, E e)
if (isBidirectionalRange!Range)
{
auto result = range;
static if (s != SwapStrategy.stable)
{
for (;!range.empty;)
{
if (range.front !is e)
{
range.popFront;
continue;
}
move(range.back, range.front);
range.popBack;
result.popBack;
}
}
else
{
auto tgt = range;
for (; !range.empty; range.popFront)
{
if (range.front is e)
{
// yank this guy
result.popBack;
continue;
}
// keep this guy
move(range.front, tgt.front);
tgt.popFront;
}
}
return result;
}

View file

@ -0,0 +1,451 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.codegen;
import std.algorithm : canFind;
import std.traits : InterfacesTuple, isSomeFunction, isSomeString;
import std.typetuple : staticIndexOf, staticMap, NoDuplicates, TypeTuple;
import thrift.codegen.base;
/**
* Removes all type qualifiers from T.
*
* In contrast to std.traits.Unqual, FullyUnqual also removes qualifiers from
* array elements (e.g. immutable(byte[]) -> byte[], not immutable(byte)[]),
* excluding strings (string isn't reduced to char[]).
*/
template FullyUnqual(T) {
static if (is(T _ == const(U), U)) {
alias FullyUnqual!U FullyUnqual;
} else static if (is(T _ == immutable(U), U)) {
alias FullyUnqual!U FullyUnqual;
} else static if (is(T _ == shared(U), U)) {
alias FullyUnqual!U FullyUnqual;
} else static if (is(T _ == U[], U) && !isSomeString!T) {
alias FullyUnqual!(U)[] FullyUnqual;
} else static if (is(T _ == V[K], K, V)) {
alias FullyUnqual!(V)[FullyUnqual!K] FullyUnqual;
} else {
alias T FullyUnqual;
}
}
/**
* true if null can be assigned to the passed type, false if not.
*/
template isNullable(T) {
enum isNullable = __traits(compiles, { T t = null; });
}
template isStruct(T) {
enum isStruct = is(T == struct);
}
template isException(T) {
enum isException = is(T : Exception);
}
template isEnum(T) {
enum isEnum = is(T == enum);
}
/**
* Aliases itself to T.name.
*/
template GetMember(T, string name) {
mixin("alias T." ~ name ~ " GetMember;");
}
/**
* Aliases itself to typeof(symbol).
*/
template TypeOf(alias symbol) {
alias typeof(symbol) TypeOf;
}
/**
* Aliases itself to the type of the T member called name.
*/
alias Compose!(TypeOf, GetMember) MemberType;
/**
* Returns the field metadata array for T if any, or an empty array otherwise.
*/
template getFieldMeta(T) if (isStruct!T || isException!T) {
static if (is(typeof(T.fieldMeta) == TFieldMeta[])) {
enum getFieldMeta = T.fieldMeta;
} else {
enum TFieldMeta[] getFieldMeta = [];
}
}
/**
* Merges the field metadata array for D with the passed array.
*/
template mergeFieldMeta(T, alias fieldMetaData = cast(TFieldMeta[])null) {
// Note: We don't use getFieldMeta here to avoid bug if it is instantiated
// from TIsSetFlags, see comment there.
static if (is(typeof(T.fieldMeta) == TFieldMeta[])) {
enum mergeFieldMeta = T.fieldMeta ~ fieldMetaData;
} else {
enum TFieldMeta[] mergeFieldMeta = fieldMetaData;
}
}
/**
* Returns the field requirement level for T.name.
*/
template memberReq(T, string name, alias fieldMetaData = cast(TFieldMeta[])null) {
enum memberReq = memberReqImpl!(T, name, fieldMetaData).result;
}
private {
import std.algorithm : find;
// DMD @@BUG@@: Missing import leads to failing build without error
// message in unittest/debug/thrift/codegen/async_client.
import std.array : empty, front;
template memberReqImpl(T, string name, alias fieldMetaData) {
enum meta = find!`a.name == b`(mergeFieldMeta!(T, fieldMetaData), name);
static if (meta.empty || meta.front.req == TReq.AUTO) {
static if (isNullable!(MemberType!(T, name))) {
enum result = TReq.OPTIONAL;
} else {
enum result = TReq.REQUIRED;
}
} else {
enum result = meta.front.req;
}
}
}
template notIgnored(T, string name, alias fieldMetaData = cast(TFieldMeta[])null) {
enum notIgnored = memberReq!(T, name, fieldMetaData) != TReq.IGNORE;
}
/**
* Returns the method metadata array for T if any, or an empty array otherwise.
*/
template getMethodMeta(T) if (isService!T) {
static if (is(typeof(T.methodMeta) == TMethodMeta[])) {
enum getMethodMeta = T.methodMeta;
} else {
enum TMethodMeta[] getMethodMeta = [];
}
}
/**
* true if T.name is a member variable. Exceptions include methods, static
* members, artifacts like package aliases,
*/
template isValueMember(T, string name) {
static if (!is(MemberType!(T, name))) {
enum isValueMember = false;
} else static if (
is(MemberType!(T, name) == void) ||
isSomeFunction!(MemberType!(T, name)) ||
__traits(compiles, { return mixin("T." ~ name); }())
) {
enum isValueMember = false;
} else {
enum isValueMember = true;
}
}
/**
* Returns a tuple containing the names of the fields of T, not including
* inherited fields. If a member is marked as TReq.IGNORE, it is not included
* as well.
*/
template FieldNames(T, alias fieldMetaData = cast(TFieldMeta[])null) {
alias StaticFilter!(
All!(
doesNotReadMembers,
PApply!(isValueMember, T),
PApply!(notIgnored, T, PApplySkip, fieldMetaData)
),
__traits(derivedMembers, T)
) FieldNames;
}
/*
* true if the passed member name is not a method generated by the
* TStructHelpers template that in its implementations queries the struct
* members.
*
* Kludge used internally to break a cycle caused a DMD forward reference
* regression, see THRIFT-2130.
*/
enum doesNotReadMembers(string name) = !["opEquals", "thriftOpEqualsImpl",
"toString", "thriftToStringImpl"].canFind(name);
template derivedMembers(T) {
alias TypeTuple!(__traits(derivedMembers, T)) derivedMembers;
}
template AllMemberMethodNames(T) if (isService!T) {
alias NoDuplicates!(
FilterMethodNames!(
T,
staticMap!(
derivedMembers,
TypeTuple!(T, InterfacesTuple!T)
)
)
) AllMemberMethodNames;
}
template FilterMethodNames(T, MemberNames...) {
alias StaticFilter!(
CompilesAndTrue!(
Compose!(isSomeFunction, TypeOf, PApply!(GetMember, T))
),
MemberNames
) FilterMethodNames;
}
/**
* Returns a type tuple containing only the elements of T for which the
* eponymous template predicate pred is true.
*
* Example:
* ---
* alias StaticFilter!(isIntegral, int, string, long, float[]) Filtered;
* static assert(is(Filtered == TypeTuple!(int, long)));
* ---
*/
template StaticFilter(alias pred, T...) {
static if (T.length == 0) {
alias TypeTuple!() StaticFilter;
} else static if (pred!(T[0])) {
alias TypeTuple!(T[0], StaticFilter!(pred, T[1 .. $])) StaticFilter;
} else {
alias StaticFilter!(pred, T[1 .. $]) StaticFilter;
}
}
/**
* Binds the first n arguments of a template to a particular value (where n is
* the number of arguments passed to PApply).
*
* The passed arguments are always applied starting from the left. However,
* the special PApplySkip marker template can be used to indicate that an
* argument should be skipped, so that e.g. the first and third argument
* to a template can be fixed, but the second and remaining arguments would
* still be left undefined.
*
* Skipping a number of parameters, but not providing enough arguments to
* assign all of them during instantiation of the resulting template is an
* error.
*
* Example:
* ---
* struct Foo(T, U, V) {}
* alias PApply!(Foo, int, long) PartialFoo;
* static assert(is(PartialFoo!float == Foo!(int, long, float)));
*
* alias PApply!(Test, int, PApplySkip, float) SkippedTest;
* static assert(is(SkippedTest!long == Test!(int, long, float)));
* ---
*/
template PApply(alias Target, T...) {
template PApply(U...) {
alias Target!(PApplyMergeArgs!(ConfinedTuple!T, U).Result) PApply;
}
}
/// Ditto.
template PApplySkip() {}
private template PApplyMergeArgs(alias Preset, Args...) {
static if (Preset.length == 0) {
alias Args Result;
} else {
enum nextSkip = staticIndexOf!(PApplySkip, Preset.Tuple);
static if (nextSkip == -1) {
alias TypeTuple!(Preset.Tuple, Args) Result;
} else static if (Args.length == 0) {
// Have to use a static if clause instead of putting the condition
// directly into the assert to avoid DMD trying to access Args[0]
// nevertheless below.
static assert(false,
"PArgsSkip encountered, but no argument left to bind.");
} else {
alias TypeTuple!(
Preset.Tuple[0 .. nextSkip],
Args[0],
PApplyMergeArgs!(
ConfinedTuple!(Preset.Tuple[nextSkip + 1 .. $]),
Args[1 .. $]
).Result
) Result;
}
}
}
unittest {
struct Test(T, U, V) {}
alias PApply!(Test, int, long) PartialTest;
static assert(is(PartialTest!float == Test!(int, long, float)));
alias PApply!(Test, int, PApplySkip, float) SkippedTest;
static assert(is(SkippedTest!long == Test!(int, long, float)));
alias PApply!(Test, int, PApplySkip, PApplySkip) TwoSkipped;
static assert(!__traits(compiles, TwoSkipped!long));
}
/**
* Composes a number of templates. The result is a template equivalent to
* all the passed templates evaluated from right to left, akin to the
* mathematical function composition notation: Instantiating Compose!(A, B, C)
* is the same as instantiating A!(B!(C!())).
*
* This is especially useful for creating a template to use with staticMap/
* StaticFilter, as demonstrated below.
*
* Example:
* ---
* template AllMethodNames(T) {
* alias StaticFilter!(
* CompilesAndTrue!(
* Compose!(isSomeFunction, TypeOf, PApply!(GetMember, T))
* ),
* __traits(allMembers, T)
* ) AllMethodNames;
* }
*
* pragma(msg, AllMethodNames!Object);
* ---
*/
template Compose(T...) {
static if (T.length == 0) {
template Compose(U...) {
alias U Compose;
}
} else {
template Compose(U...) {
alias Instantiate!(T[0], Instantiate!(.Compose!(T[1 .. $]), U)) Compose;
}
}
}
/**
* Instantiates the given template with the given list of parameters.
*
* Used to work around syntactic limiations of D with regard to instantiating
* a template from a type tuple (e.g. T[0]!(...) is not valid) or a template
* returning another template (e.g. Foo!(Bar)!(Baz) is not allowed).
*/
template Instantiate(alias Template, Params...) {
alias Template!Params Instantiate;
}
/**
* Combines several template predicates using logical AND, i.e. instantiating
* All!(a, b, c) with parameters P for some templates a, b, c is equivalent to
* a!P && b!P && c!P.
*
* The templates are evaluated from left to right, aborting evaluation in a
* shurt-cut manner if a false result is encountered, in which case the latter
* instantiations do not need to compile.
*/
template All(T...) {
static if (T.length == 0) {
template All(U...) {
enum All = true;
}
} else {
template All(U...) {
static if (Instantiate!(T[0], U)) {
alias Instantiate!(.All!(T[1 .. $]), U) All;
} else {
enum All = false;
}
}
}
}
/**
* Combines several template predicates using logical OR, i.e. instantiating
* Any!(a, b, c) with parameters P for some templates a, b, c is equivalent to
* a!P || b!P || c!P.
*
* The templates are evaluated from left to right, aborting evaluation in a
* shurt-cut manner if a true result is encountered, in which case the latter
* instantiations do not need to compile.
*/
template Any(T...) {
static if (T.length == 0) {
template Any(U...) {
enum Any = false;
}
} else {
template Any(U...) {
static if (Instantiate!(T[0], U)) {
enum Any = true;
} else {
alias Instantiate!(.Any!(T[1 .. $]), U) Any;
}
}
}
}
template ConfinedTuple(T...) {
alias T Tuple;
enum length = T.length;
}
/*
* foreach (Item; Items) {
* List = Operator!(Item, List);
* }
* where Items is a ConfinedTuple and List is a type tuple.
*/
template ForAllWithList(alias Items, alias Operator, List...) if (
is(typeof(Items.length) : size_t)
){
static if (Items.length == 0) {
alias List ForAllWithList;
} else {
alias .ForAllWithList!(
ConfinedTuple!(Items.Tuple[1 .. $]),
Operator,
Operator!(Items.Tuple[0], List)
) ForAllWithList;
}
}
/**
* Wraps the passed template predicate so it returns true if it compiles and
* evaluates to true, false it it doesn't compile or evaluates to false.
*/
template CompilesAndTrue(alias T) {
template CompilesAndTrue(U...) {
static if (is(typeof(T!U) : bool)) {
enum bool CompilesAndTrue = T!U;
} else {
enum bool CompilesAndTrue = false;
}
}
}

View file

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.ctfe;
import std.conv : to;
import std.traits;
/*
* Simple eager join() for strings, std.algorithm.join isn't CTFEable yet.
*/
string ctfeJoin(string[] strings, string separator = ", ") {
string result;
if (strings.length > 0) {
result ~= strings[0];
foreach (s; strings[1..$]) {
result ~= separator ~ s;
}
}
return result;
}
/*
* A very primitive to!string() implementation for floating point numbers that
* is evaluatable at compile time.
*
* There is a wealth of problems associated with the algorithm used (e.g. 5.0
* prints as 4.999, incorrect rounding, etc.), but a better alternative should
* be included with the D standard library instead of implementing it here.
*/
string ctfeToString(T)(T val) if (isFloatingPoint!T) {
if (val is T.nan) return "nan";
if (val is T.infinity) return "inf";
if (val is -T.infinity) return "-inf";
if (val is 0.0) return "0";
if (val is -0.0) return "-0";
auto b = val;
string result;
if (b < 0) {
result ~= '-';
b *= -1;
}
short magnitude;
while (b >= 10) {
++magnitude;
b /= 10;
}
while (b < 1) {
--magnitude;
b *= 10;
}
foreach (i; 0 .. T.dig) {
if (i == 1) result ~= '.';
auto first = cast(ubyte)b;
result ~= to!string(first);
b -= first;
import std.math;
if (b < pow(10.0, i - T.dig)) break;
b *= 10;
}
if (magnitude != 0) result ~= "e" ~ to!string(magnitude);
return result;
}
unittest {
import std.algorithm;
static assert(ctfeToString(double.infinity) == "inf");
static assert(ctfeToString(-double.infinity) == "-inf");
static assert(ctfeToString(double.nan) == "nan");
static assert(ctfeToString(0.0) == "0");
static assert(ctfeToString(-0.0) == "-0");
static assert(ctfeToString(2.5) == "2.5");
static assert(ctfeToString(3.1415).startsWith("3.141"));
static assert(ctfeToString(2e-200) == "2e-200");
}

View file

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Simple helpers for handling typical byte order-related issues.
*/
module thrift.internal.endian;
import core.bitop : bswap;
import std.traits : isIntegral;
union IntBuf(T) {
ubyte[T.sizeof] bytes;
T value;
}
T byteSwap(T)(T t) pure nothrow @trusted if (isIntegral!T) {
static if (T.sizeof == 2) {
return cast(T)((t & 0xff) << 8) | cast(T)((t & 0xff00) >> 8);
} else static if (T.sizeof == 4) {
return cast(T)bswap(cast(uint)t);
} else static if (T.sizeof == 8) {
return cast(T)byteSwap(cast(uint)(t & 0xffffffff)) << 32 |
cast(T)bswap(cast(uint)(t >> 32));
} else static assert(false, "Type of size " ~ to!string(T.sizeof) ~ " not supported.");
}
T doNothing(T)(T val) { return val; }
version (BigEndian) {
alias doNothing hostToNet;
alias doNothing netToHost;
alias byteSwap hostToLe;
alias byteSwap leToHost;
} else {
alias byteSwap hostToNet;
alias byteSwap netToHost;
alias doNothing hostToLe;
alias doNothing leToHost;
}
unittest {
import std.exception;
IntBuf!short s;
s.bytes = [1, 2];
s.value = byteSwap(s.value);
enforce(s.bytes == [2, 1]);
IntBuf!int i;
i.bytes = [1, 2, 3, 4];
i.value = byteSwap(i.value);
enforce(i.bytes == [4, 3, 2, 1]);
IntBuf!long l;
l.bytes = [1, 2, 3, 4, 5, 6, 7, 8];
l.value = byteSwap(l.value);
enforce(l.bytes == [8, 7, 6, 5, 4, 3, 2, 1]);
}

View file

@ -0,0 +1,431 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.resource_pool;
import core.time : Duration, dur, TickDuration;
import std.algorithm : minPos, reduce, remove;
import std.array : array, empty;
import std.exception : enforce;
import std.conv : to;
import std.random : randomCover, rndGen;
import std.range : zip;
import thrift.internal.algorithm : removeEqual;
/**
* A pool of resources, which can be iterated over, and where resources that
* have failed too often can be temporarily disabled.
*
* This class is oblivious to the actual resource type managed.
*/
final class TResourcePool(Resource) {
/**
* Constructs a new instance.
*
* Params:
* resources = The initial members of the pool.
*/
this(Resource[] resources) {
resources_ = resources;
}
/**
* Adds a resource to the pool.
*/
void add(Resource resource) {
resources_ ~= resource;
}
/**
* Removes a resource from the pool.
*
* Returns: Whether the resource could be found in the pool.
*/
bool remove(Resource resource) {
auto oldLength = resources_.length;
resources_ = removeEqual(resources_, resource);
return resources_.length < oldLength;
}
/**
* Returns an »enriched« input range to iterate over the pool members.
*/
static struct Range {
/**
* Whether the range is empty.
*
* This is the case if all members of the pool have been popped (or skipped
* because they were disabled) and TResourcePool.cycle is false, or there
* is no element to return in cycle mode because all have been temporarily
* disabled.
*/
bool empty() @property {
// If no resources are in the pool, the range will never become non-empty.
if (resources_.empty) return true;
// If we already got the next resource in the cache, it doesn't matter
// whether there are more.
if (cached_) return false;
size_t examineCount;
if (parent_.cycle) {
// We want to check all the resources, but not iterate more than once
// to avoid spinning in a loop if nothing is available.
examineCount = resources_.length;
} else {
// When not in cycle mode, we just iterate the list exactly once. If all
// items have been consumed, the interval below is empty.
examineCount = resources_.length - nextIndex_;
}
foreach (i; 0 .. examineCount) {
auto r = resources_[(nextIndex_ + i) % resources_.length];
auto fi = r in parent_.faultInfos_;
if (fi && fi.resetTime != fi.resetTime.init) {
if (fi.resetTime < parent_.getCurrentTick_()) {
// The timeout expired, remove the resource from the list and go
// ahead trying it.
parent_.faultInfos_.remove(r);
} else {
// The timeout didn't expire yet, try the next resource.
continue;
}
}
cache_ = r;
cached_ = true;
nextIndex_ = nextIndex_ + i + 1;
return false;
}
// If we get here, all resources are currently inactive or the non-cycle
// pool has been exhausted, so there is nothing we can do.
nextIndex_ = nextIndex_ + examineCount;
return true;
}
/**
* Returns the first resource in the range.
*/
Resource front() @property {
enforce(!empty);
return cache_;
}
/**
* Removes the first resource from the range.
*
* Usually, this is combined with a call to TResourcePool.recordSuccess()
* or recordFault().
*/
void popFront() {
enforce(!empty);
cached_ = false;
}
/**
* Returns whether the range will become non-empty at some point in the
* future, and provides additional information when this will happen and
* what will be the next resource.
*
* Makes only sense to call on empty ranges.
*
* Params:
* next = The next resource that will become available.
* waitTime = The duration until that resource will become available.
*/
bool willBecomeNonempty(out Resource next, out Duration waitTime) {
// If no resources are in the pool, the range will never become non-empty.
if (resources_.empty) return false;
// If cycle mode is not enabled, a range never becomes non-empty after
// being empty once, because all the elements have already been
// used/skipped in order to become empty.
if (!parent_.cycle) return false;
auto fi = parent_.faultInfos_;
auto nextPair = minPos!"a[1].resetTime < b[1].resetTime"(
zip(fi.keys, fi.values)
).front;
next = nextPair[0];
waitTime = to!Duration(nextPair[1].resetTime - parent_.getCurrentTick_());
return true;
}
private:
this(TResourcePool parent, Resource[] resources) {
parent_ = parent;
resources_ = resources;
}
TResourcePool parent_;
/// All available resources. We keep a copy of it as to not get confused
/// when resources are added to/removed from the parent pool.
Resource[] resources_;
/// After we have determined the next element in empty(), we store it here.
Resource cache_;
/// Whether there is currently something in the cache.
bool cached_;
/// The index to start searching from at the next call to empty().
size_t nextIndex_;
}
/// Ditto
Range opSlice() {
auto res = resources_;
if (permute) {
res = array(randomCover(res, rndGen));
}
return Range(this, res);
}
/**
* Records a success for an operation on the given resource, cancelling a
* fault streak, if any.
*/
void recordSuccess(Resource resource) {
if (resource in faultInfos_) {
faultInfos_.remove(resource);
}
}
/**
* Records a fault for the given resource.
*
* If a resource fails consecutively for more than faultDisableCount times,
* it is temporarily disabled (no longer considered) until
* faultDisableDuration has passed.
*/
void recordFault(Resource resource) {
auto fi = resource in faultInfos_;
if (!fi) {
faultInfos_[resource] = FaultInfo();
fi = resource in faultInfos_;
}
++fi.count;
if (fi.count >= faultDisableCount) {
// If the resource has hit the fault count limit, disable it for
// specified duration.
fi.resetTime = getCurrentTick_() + cast(TickDuration)faultDisableDuration;
}
}
/**
* Whether to randomly permute the order of the resources in the pool when
* taking a range using opSlice().
*
* This can be used e.g. as a simple form of load balancing.
*/
bool permute = true;
/**
* Whether to keep iterating over the pool members after all have been
* returned/have failed once.
*/
bool cycle = false;
/**
* The number of consecutive faults after which a resource is disabled until
* faultDisableDuration has passed. Zero to never disable resources.
*
* Defaults to zero.
*/
ushort faultDisableCount = 0;
/**
* The duration for which a resource is no longer considered after it has
* failed too often.
*
* Defaults to one second.
*/
Duration faultDisableDuration = dur!"seconds"(1);
private:
Resource[] resources_;
FaultInfo[Resource] faultInfos_;
/// Function to get the current timestamp from some monotonic system clock.
///
/// This is overridable to be able to write timing-insensitive unit tests.
/// The extra indirection should not matter much performance-wise compared to
/// the actual system call, and by its very nature thisshould not be on a hot
/// path anyway.
typeof(&TickDuration.currSystemTick) getCurrentTick_ =
&TickDuration.currSystemTick;
}
private {
struct FaultInfo {
ushort count;
TickDuration resetTime;
}
}
unittest {
auto pool = new TResourcePool!Object([]);
enforce(pool[].empty);
Object dummyRes;
Duration dummyDur;
enforce(!pool[].willBecomeNonempty(dummyRes, dummyDur));
}
unittest {
import std.datetime;
import thrift.base;
auto a = new Object;
auto b = new Object;
auto c = new Object;
auto objs = [a, b, c];
auto pool = new TResourcePool!Object(objs);
pool.permute = false;
static Duration fakeClock;
pool.getCurrentTick_ = () => cast(TickDuration)fakeClock;
Object dummyRes = void;
Duration dummyDur = void;
{
auto r = pool[];
foreach (i, o; objs) {
enforce(!r.empty);
enforce(r.front == o);
r.popFront();
}
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
}
{
pool.faultDisableCount = 2;
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordSuccess(a);
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordFault(a);
auto r = pool[];
enforce(r.front == b);
r.popFront();
enforce(r.front == c);
r.popFront();
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
fakeClock += 2.seconds;
// Not in cycle mode, has to be still empty after the timeouts expired.
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
foreach (o; objs) pool.recordSuccess(o);
}
{
pool.faultDisableCount = 1;
pool.recordFault(a);
pool.recordFault(b);
pool.recordFault(c);
auto r = pool[];
enforce(r.empty);
enforce(!r.willBecomeNonempty(dummyRes, dummyDur));
foreach (o; objs) pool.recordSuccess(o);
}
pool.cycle = true;
{
auto r = pool[];
foreach (o; objs ~ objs) {
enforce(!r.empty);
enforce(r.front == o);
r.popFront();
}
}
{
pool.faultDisableCount = 2;
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordSuccess(a);
enforce(pool[].front == a);
pool.recordFault(a);
enforce(pool[].front == a);
pool.recordFault(a);
auto r = pool[];
enforce(r.front == b);
r.popFront();
enforce(r.front == c);
r.popFront();
enforce(r.front == b);
fakeClock += 2.seconds;
r.popFront();
enforce(r.front == c);
r.popFront();
enforce(r.front == a);
enforce(pool[].front == a);
foreach (o; objs) pool.recordSuccess(o);
}
{
pool.faultDisableCount = 1;
pool.recordFault(a);
fakeClock += 1.msecs;
pool.recordFault(b);
fakeClock += 1.msecs;
pool.recordFault(c);
auto r = pool[];
enforce(r.empty);
// Make sure willBecomeNonempty gets the order right.
enforce(r.willBecomeNonempty(dummyRes, dummyDur));
enforce(dummyRes == a);
enforce(dummyDur > Duration.zero);
foreach (o; objs) pool.recordSuccess(o);
}
}

View file

@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Abstractions over OS-dependent socket functionality.
*/
module thrift.internal.socket;
import std.conv : to;
// FreeBSD and OS X return -1 and set ECONNRESET if socket was closed by
// the other side, we need to check for that before throwing an exception.
version (FreeBSD) {
enum connresetOnPeerShutdown = true;
} else version (OSX) {
enum connresetOnPeerShutdown = true;
} else {
enum connresetOnPeerShutdown = false;
}
version (Win32) {
import std.c.windows.winsock : WSAGetLastError, WSAEINTR, WSAEWOULDBLOCK;
import std.windows.syserror : sysErrorString;
// These are unfortunately not defined in std.c.windows.winsock, see
// http://msdn.microsoft.com/en-us/library/ms740668.aspx.
enum WSAECONNRESET = 10054;
enum WSAENOTCONN = 10057;
enum WSAETIMEDOUT = 10060;
} else {
import core.stdc.errno : errno, EAGAIN, ECONNRESET, EINPROGRESS, EINTR,
ENOTCONN, EPIPE;
import core.stdc.string : strerror;
}
/*
* CONNECT_INPROGRESS_ERRNO: set by connect() for non-blocking sockets if the
* connection could not be immediately established.
* INTERRUPTED_ERRNO: set when blocking system calls are interrupted by
* signals or similar.
* TIMEOUT_ERRNO: set when a socket timeout has been exceeded.
* WOULD_BLOCK_ERRNO: set when send/recv would block on non-blocking sockets.
*
* isSocetCloseErrno(errno): returns true if errno indicates that the socket
* is logically in closed state now.
*/
version (Win32) {
alias WSAGetLastError getSocketErrno;
enum CONNECT_INPROGRESS_ERRNO = WSAEWOULDBLOCK;
enum INTERRUPTED_ERRNO = WSAEINTR;
enum TIMEOUT_ERRNO = WSAETIMEDOUT;
enum WOULD_BLOCK_ERRNO = WSAEWOULDBLOCK;
bool isSocketCloseErrno(typeof(getSocketErrno()) errno) {
return (errno == WSAECONNRESET || errno == WSAENOTCONN);
}
} else {
alias errno getSocketErrno;
enum CONNECT_INPROGRESS_ERRNO = EINPROGRESS;
enum INTERRUPTED_ERRNO = EINTR;
enum WOULD_BLOCK_ERRNO = EAGAIN;
// TODO: The C++ TSocket implementation mentions that EAGAIN can also be
// set (undocumentedly) in out of resource conditions; it would be a good
// idea to contact the original authors of the C++ code for details and adapt
// the code accordingly.
enum TIMEOUT_ERRNO = EAGAIN;
bool isSocketCloseErrno(typeof(getSocketErrno()) errno) {
return (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN);
}
}
string socketErrnoString(uint errno) {
version (Win32) {
return sysErrorString(errno);
} else {
return to!string(strerror(errno));
}
}

View file

@ -0,0 +1,240 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.ssl;
import core.memory : GC;
import core.stdc.config;
import core.stdc.errno : errno;
import core.stdc.string : strerror;
import deimos.openssl.err;
import deimos.openssl.ssl;
import deimos.openssl.x509v3;
import std.array : empty, appender;
import std.conv : to;
import std.socket : Address;
import thrift.transport.ssl;
/**
* Checks if the peer is authorized after the SSL handshake has been
* completed on the given conncetion and throws an TSSLException if not.
*
* Params:
* ssl = The SSL connection to check.
* accessManager = The access manager to check the peer againts.
* peerAddress = The (IP) address of the peer.
* hostName = The host name of the peer.
*/
void authorize(SSL* ssl, TAccessManager accessManager,
Address peerAddress, lazy string hostName
) {
alias TAccessManager.Decision Decision;
auto rc = SSL_get_verify_result(ssl);
if (rc != X509_V_OK) {
throw new TSSLException("SSL_get_verify_result(): " ~
to!string(X509_verify_cert_error_string(rc)));
}
auto cert = SSL_get_peer_certificate(ssl);
if (cert is null) {
// Certificate is not present.
if (SSL_get_verify_mode(ssl) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
throw new TSSLException(
"Authorize: Required certificate not present.");
}
// If we don't have an access manager set, we don't intend to authorize
// the client, so everything's fine.
if (accessManager) {
throw new TSSLException(
"Authorize: Certificate required for authorization.");
}
return;
}
if (accessManager is null) {
// No access manager set, can return immediately as the cert is valid
// and all peers are authorized.
X509_free(cert);
return;
}
// both certificate and access manager are present
auto decision = accessManager.verify(peerAddress);
if (decision != Decision.SKIP) {
X509_free(cert);
if (decision != Decision.ALLOW) {
throw new TSSLException("Authorize: Access denied based on remote IP.");
}
return;
}
// Check subjectAltName(s), if present.
auto alternatives = cast(STACK_OF!(GENERAL_NAME)*)
X509_get_ext_d2i(cert, NID_subject_alt_name, null, null);
if (alternatives != null) {
auto count = sk_GENERAL_NAME_num(alternatives);
for (int i = 0; decision == Decision.SKIP && i < count; i++) {
auto name = sk_GENERAL_NAME_value(alternatives, i);
if (name is null) {
continue;
}
auto data = ASN1_STRING_data(name.d.ia5);
auto length = ASN1_STRING_length(name.d.ia5);
switch (name.type) {
case GENERAL_NAME.GEN_DNS:
decision = accessManager.verify(hostName, cast(char[])data[0 .. length]);
break;
case GENERAL_NAME.GEN_IPADD:
decision = accessManager.verify(peerAddress, data[0 .. length]);
break;
default:
// Do nothing.
}
}
// DMD @@BUG@@: Empty template arguments parens should not be needed.
sk_GENERAL_NAME_pop_free!()(alternatives, &GENERAL_NAME_free);
}
// If we are alredy done, return.
if (decision != Decision.SKIP) {
X509_free(cert);
if (decision != Decision.ALLOW) {
throw new TSSLException("Authorize: Access denied.");
}
return;
}
// Check commonName.
auto name = X509_get_subject_name(cert);
if (name !is null) {
X509_NAME_ENTRY* entry;
char* utf8;
int last = -1;
while (decision == Decision.SKIP) {
last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
if (last == -1)
break;
entry = X509_NAME_get_entry(name, last);
if (entry is null)
continue;
auto common = X509_NAME_ENTRY_get_data(entry);
auto size = ASN1_STRING_to_UTF8(&utf8, common);
decision = accessManager.verify(hostName, utf8[0 .. size]);
CRYPTO_free(utf8);
}
}
X509_free(cert);
if (decision != Decision.ALLOW) {
throw new TSSLException("Authorize: Could not authorize peer.");
}
}
/*
* OpenSSL error information used for storing D exceptions on the OpenSSL
* error stack.
*/
enum ERR_LIB_D_EXCEPTION = ERR_LIB_USER;
enum ERR_F_D_EXCEPTION = 0; // function id - what to use here?
enum ERR_R_D_EXCEPTION = 1234; // 99 and above are reserved for applications
enum ERR_FILE_D_EXCEPTION = "d_exception";
enum ERR_LINE_D_EXCEPTION = 0;
enum ERR_FLAGS_D_EXCEPTION = 0;
/**
* Returns an exception for the last.
*
* Params:
* location = An optional "location" to add to the error message (typically
* the last SSL API call).
*/
Exception getSSLException(string location = null, string clientFile = __FILE__,
size_t clientLine = __LINE__
) {
// We can return either an exception saved from D BIO code, or a "true"
// OpenSSL error. Because there can possibly be more than one error on the
// error stack, we have to fetch all of them, and pick the last, i.e. newest
// one. We concatenate multiple successive OpenSSL error messages into a
// single one, but always just return the last D expcetion.
string message; // Probably better use an Appender here.
bool hadMessage;
Exception exception;
void initMessage() {
message.destroy();
hadMessage = false;
if (!location.empty) {
message ~= location;
message ~= ": ";
}
}
initMessage();
auto errn = errno;
const(char)* file = void;
int line = void;
const(char)* data = void;
int flags = void;
c_ulong code = void;
while ((code = ERR_get_error_line_data(&file, &line, &data, &flags)) != 0) {
if (ERR_GET_REASON(code) == ERR_R_D_EXCEPTION) {
initMessage();
GC.removeRoot(cast(void*)data);
exception = cast(Exception)data;
} else {
exception = null;
if (hadMessage) {
message ~= ", ";
}
auto reason = ERR_reason_error_string(code);
if (reason) {
message ~= "SSL error: " ~ to!string(reason);
} else {
message ~= "SSL error #" ~ to!string(code);
}
hadMessage = true;
}
}
// If the last item from the stack was a D exception, throw it.
if (exception) return exception;
// We are dealing with an OpenSSL error that doesn't root in a D exception.
if (!hadMessage) {
// If we didn't get an actual error from the stack yet, try errno.
string errnString;
if (errn != 0) {
errnString = to!string(strerror(errn));
}
if (errnString.empty) {
message ~= "Unknown error";
} else {
message ~= errnString;
}
}
message ~= ".";
return new TSSLException(message, clientFile, clientLine);
}

View file

@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Provides a SSL BIO implementation wrapping a Thrift transport.
*
* This way, SSL I/O can be relayed over Thrift transport without introducing
* an additional layer of buffering, especially for the non-blocking
* transports.
*
* For the Thrift transport incarnations of the SSL entities, "tt" is used as
* prefix for clarity.
*/
module thrift.internal.ssl_bio;
import core.stdc.config;
import core.stdc.string : strlen;
import core.memory : GC;
import deimos.openssl.bio;
import deimos.openssl.err;
import thrift.base;
import thrift.internal.ssl;
import thrift.transport.base;
/**
* Creates an SSL BIO object wrapping the given transport.
*
* Exceptions thrown by the transport are pushed onto the OpenSSL error stack,
* using the location/reason values from thrift.internal.ssl.ERR_*_D_EXCEPTION.
*
* The transport is assumed to be ready for reading and writing when the BIO
* functions are called, it is not opened by the implementation.
*
* Params:
* transport = The transport to wrap.
* closeTransport = Whether the close the transport when the SSL BIO is
* closed.
*/
BIO* createTTransportBIO(TTransport transport, bool closeTransport) {
auto result = BIO_new(cast(BIO_METHOD*)&ttBioMethod);
if (!result) return null;
GC.addRoot(cast(void*)transport);
BIO_set_fd(result, closeTransport, cast(c_long)cast(void*)transport);
return result;
}
private {
// Helper to get the Thrift transport assigned with the given BIO.
TTransport trans(BIO* b) nothrow {
auto result = cast(TTransport)b.ptr;
assert(result);
return result;
}
void setError(Exception e) nothrow {
ERR_put_error(ERR_LIB_D_EXCEPTION, ERR_F_D_EXCEPTION, ERR_R_D_EXCEPTION,
ERR_FILE_D_EXCEPTION, ERR_LINE_D_EXCEPTION);
try { GC.addRoot(cast(void*)e); } catch {}
ERR_set_error_data(cast(char*)e, ERR_FLAGS_D_EXCEPTION);
}
extern(C) int ttWrite(BIO* b, const(char)* data, int length) nothrow {
assert(b);
if (!data || length <= 0) return 0;
try {
trans(b).write((cast(ubyte*)data)[0 .. length]);
return length;
} catch (Exception e) {
setError(e);
return -1;
}
}
extern(C) int ttRead(BIO* b, char* data, int length) nothrow {
assert(b);
if (!data || length <= 0) return 0;
try {
return cast(int)trans(b).read((cast(ubyte*)data)[0 .. length]);
} catch (Exception e) {
setError(e);
return -1;
}
}
extern(C) int ttPuts(BIO* b, const(char)* str) nothrow {
return ttWrite(b, str, cast(int)strlen(str));
}
extern(C) c_long ttCtrl(BIO* b, int cmd, c_long num, void* ptr) nothrow {
assert(b);
switch (cmd) {
case BIO_C_SET_FD:
// Note that close flag and "fd" are actually reversed here because we
// need 64 bit width for the pointer should probably drop BIO_set_fd
// altogether.
ttDestroy(b);
b.ptr = cast(void*)num;
b.shutdown = cast(int)ptr;
b.init_ = 1;
return 1;
case BIO_C_GET_FD:
if (!b.init_) return -1;
*(cast(void**)ptr) = b.ptr;
return cast(c_long)b.ptr;
case BIO_CTRL_GET_CLOSE:
return b.shutdown;
case BIO_CTRL_SET_CLOSE:
b.shutdown = cast(int)num;
return 1;
case BIO_CTRL_FLUSH:
try {
trans(b).flush();
return 1;
} catch (Exception e) {
setError(e);
return -1;
}
case BIO_CTRL_DUP:
// Seems like we have nothing to do on duplication, but couldn't find
// any documentation if this actually ever happens during normal SSL
// usage.
return 1;
default:
return 0;
}
}
extern(C) int ttCreate(BIO* b) nothrow {
assert(b);
b.init_ = 0;
b.num = 0; // User-defined number field, unused here.
b.ptr = null;
b.flags = 0;
return 1;
}
extern(C) int ttDestroy(BIO* b) nothrow {
if (!b) return 0;
int rc = 1;
if (b.shutdown) {
if (b.init_) {
try {
trans(b).close();
GC.removeRoot(cast(void*)trans(b));
b.ptr = null;
} catch (Exception e) {
setError(e);
rc = -1;
}
}
b.init_ = 0;
b.flags = 0;
}
return rc;
}
immutable BIO_METHOD ttBioMethod = {
BIO_TYPE_SOURCE_SINK,
"TTransport",
&ttWrite,
&ttRead,
&ttPuts,
null, // gets
&ttCtrl,
&ttCreate,
&ttDestroy,
null // callback_ctrl
};
}

View file

@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.test.protocol;
import std.exception;
import thrift.transport.memory;
import thrift.protocol.base;
version (unittest):
void testContainerSizeLimit(Protocol)() if (isTProtocol!Protocol) {
auto buffer = new TMemoryBuffer;
auto prot = new Protocol(buffer);
// Make sure reading fails if a container larger than the size limit is read.
prot.containerSizeLimit = 3;
{
prot.writeListBegin(TList(TType.I32, 4));
prot.writeI32(0); // Make sure size can be read e.g. for JSON protocol.
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readListBegin());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
{
prot.writeMapBegin(TMap(TType.I32, TType.I32, 4));
prot.writeI32(0); // Make sure size can be read e.g. for JSON protocol.
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readMapBegin());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
{
prot.writeSetBegin(TSet(TType.I32, 4));
prot.writeI32(0); // Make sure size can be read e.g. for JSON protocol.
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readSetBegin());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
// Make sure reading works if the containers are smaller than the limit or
// no limit is set.
foreach (limit; [3, 0, -1]) {
prot.containerSizeLimit = limit;
{
prot.writeListBegin(TList(TType.I32, 2));
prot.writeI32(0);
prot.writeI32(1);
prot.writeListEnd();
prot.reset();
auto list = prot.readListBegin();
enforce(list.elemType == TType.I32);
enforce(list.size == 2);
enforce(prot.readI32() == 0);
enforce(prot.readI32() == 1);
prot.readListEnd();
prot.reset();
buffer.reset();
}
{
prot.writeMapBegin(TMap(TType.I32, TType.I32, 2));
prot.writeI32(0);
prot.writeI32(1);
prot.writeI32(2);
prot.writeI32(3);
prot.writeMapEnd();
prot.reset();
auto map = prot.readMapBegin();
enforce(map.keyType == TType.I32);
enforce(map.valueType == TType.I32);
enforce(map.size == 2);
enforce(prot.readI32() == 0);
enforce(prot.readI32() == 1);
enforce(prot.readI32() == 2);
enforce(prot.readI32() == 3);
prot.readMapEnd();
prot.reset();
buffer.reset();
}
{
prot.writeSetBegin(TSet(TType.I32, 2));
prot.writeI32(0);
prot.writeI32(1);
prot.writeSetEnd();
prot.reset();
auto set = prot.readSetBegin();
enforce(set.elemType == TType.I32);
enforce(set.size == 2);
enforce(prot.readI32() == 0);
enforce(prot.readI32() == 1);
prot.readSetEnd();
prot.reset();
buffer.reset();
}
}
}
void testStringSizeLimit(Protocol)() if (isTProtocol!Protocol) {
auto buffer = new TMemoryBuffer;
auto prot = new Protocol(buffer);
// Make sure reading fails if a string larger than the size limit is read.
prot.stringSizeLimit = 3;
{
prot.writeString("asdf");
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readString());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
{
prot.writeBinary([1, 2, 3, 4]);
prot.reset();
auto e = cast(TProtocolException)collectException(prot.readBinary());
enforce(e && e.type == TProtocolException.Type.SIZE_LIMIT);
prot.reset();
buffer.reset();
}
// Make sure reading works if the containers are smaller than the limit or
// no limit is set.
foreach (limit; [3, 0, -1]) {
prot.containerSizeLimit = limit;
{
prot.writeString("as");
prot.reset();
enforce(prot.readString() == "as");
prot.reset();
buffer.reset();
}
{
prot.writeBinary([1, 2]);
prot.reset();
enforce(prot.readBinary() == [1, 2]);
prot.reset();
buffer.reset();
}
}
}

View file

@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.test.server;
import core.sync.condition;
import core.sync.mutex;
import core.thread : Thread;
import std.datetime;
import std.exception : enforce;
import std.typecons : WhiteHole;
import std.variant : Variant;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.socket;
import thrift.transport.base;
import thrift.util.cancellation;
version(unittest):
/**
* Tests if serving is stopped correctly if the cancellation passed to serve()
* is triggered.
*
* Because the tests are run many times in a loop, this is indirectly also a
* test whether socket, etc. handles are cleaned up correctly, because the
* application will likely run out of handles otherwise.
*/
void testServeCancel(Server)(void delegate(Server) serverSetup = null) if (
is(Server : TServer)
) {
auto proc = new WhiteHole!TProcessor;
auto tf = new TTransportFactory;
auto pf = new TBinaryProtocolFactory!();
// Need a special case for TNonblockingServer which doesn't use
// TServerTransport.
static if (__traits(compiles, new Server(proc, 0, tf, pf))) {
auto server = new Server(proc, 0, tf, pf);
} else {
auto server = new Server(proc, new TServerSocket(0), tf, pf);
}
// On Windows, we use TCP sockets to replace socketpair(). Since they stay
// in TIME_WAIT for some time even if they are properly closed, we have to use
// a lower number of iterations to avoid running out of ports/buffer space.
version (Windows) {
enum ITERATIONS = 100;
} else {
enum ITERATIONS = 10000;
}
if (serverSetup) serverSetup(server);
auto servingMutex = new Mutex;
auto servingCondition = new Condition(servingMutex);
auto doneMutex = new Mutex;
auto doneCondition = new Condition(doneMutex);
class CancellingHandler : TServerEventHandler {
void preServe() {
synchronized (servingMutex) {
servingCondition.notifyAll();
}
}
Variant createContext(TProtocol input, TProtocol output) { return Variant.init; }
void deleteContext(Variant serverContext, TProtocol input, TProtocol output) {}
void preProcess(Variant serverContext, TTransport transport) {}
}
server.eventHandler = new CancellingHandler;
foreach (i; 0 .. ITERATIONS) {
synchronized (servingMutex) {
auto cancel = new TCancellationOrigin;
synchronized (doneMutex) {
auto serverThread = new Thread({
server.serve(cancel);
synchronized (doneMutex) {
doneCondition.notifyAll();
}
});
serverThread.isDaemon = true;
serverThread.start();
servingCondition.wait();
cancel.trigger();
enforce(doneCondition.wait(dur!"msecs"(3*1000)));
serverThread.join();
}
}
}
}

View file

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.internal.traits;
import std.traits;
/**
* Adds »nothrow« to the type of the passed function pointer/delegate, if it
* is not already present.
*
* Technically, assumeNothrow just performs a cast, but using it has the
* advantage of being explicitly about the operation that is performed.
*/
auto assumeNothrow(T)(T t) if (isFunctionPointer!T || isDelegate!T) {
enum attrs = functionAttributes!T | FunctionAttribute.nothrow_;
return cast(SetFunctionAttributes!(T, functionLinkage!T, attrs)) t;
}

View file

@ -0,0 +1,449 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Defines the basic interface for a Thrift protocol and associated exception
* types.
*
* Most parts of the protocol API are typically not used in client code, as
* the actual serialization code is generated by thrift.codegen.* the only
* interesting thing usually is that there are protocols which can be created
* from transports and passed around.
*/
module thrift.protocol.base;
import thrift.base;
import thrift.transport.base;
/**
* The field types Thrift protocols support.
*/
enum TType : byte {
STOP = 0, /// Used to mark the end of a sequence of fields.
VOID = 1, ///
BOOL = 2, ///
BYTE = 3, ///
DOUBLE = 4, ///
I16 = 6, ///
I32 = 8, ///
I64 = 10, ///
STRING = 11, ///
STRUCT = 12, ///
MAP = 13, ///
SET = 14, ///
LIST = 15 ///
}
/**
* Types of Thrift RPC messages.
*/
enum TMessageType : byte {
CALL = 1, /// Call of a normal, two-way RPC method.
REPLY = 2, /// Reply to a normal method call.
EXCEPTION = 3, /// Reply to a method call if target raised a TApplicationException.
ONEWAY = 4 /// Call of a one-way RPC method which is not followed by a reply.
}
/**
* Descriptions of Thrift entities.
*/
struct TField {
string name;
TType type;
short id;
}
/// ditto
struct TList {
TType elemType;
size_t size;
}
/// ditto
struct TMap {
TType keyType;
TType valueType;
size_t size;
}
/// ditto
struct TMessage {
string name;
TMessageType type;
int seqid;
}
/// ditto
struct TSet {
TType elemType;
size_t size;
}
/// ditto
struct TStruct {
string name;
}
/**
* Interface for a Thrift protocol implementation. Essentially, it defines
* a way of reading and writing all the base types, plus a mechanism for
* writing out structs with indexed fields.
*
* TProtocol objects should not be shared across multiple encoding contexts,
* as they may need to maintain internal state in some protocols (e.g. JSON).
* Note that is is acceptable for the TProtocol module to do its own internal
* buffered reads/writes to the underlying TTransport where appropriate (i.e.
* when parsing an input XML stream, reading could be batched rather than
* looking ahead character by character for a close tag).
*/
interface TProtocol {
/// The underlying transport used by the protocol.
TTransport transport() @property;
/*
* Writing methods.
*/
void writeBool(bool b); ///
void writeByte(byte b); ///
void writeI16(short i16); ///
void writeI32(int i32); ///
void writeI64(long i64); ///
void writeDouble(double dub); ///
void writeString(string str); ///
void writeBinary(ubyte[] buf); ///
void writeMessageBegin(TMessage message); ///
void writeMessageEnd(); ///
void writeStructBegin(TStruct tstruct); ///
void writeStructEnd(); ///
void writeFieldBegin(TField field); ///
void writeFieldEnd(); ///
void writeFieldStop(); ///
void writeListBegin(TList list); ///
void writeListEnd(); ///
void writeMapBegin(TMap map); ///
void writeMapEnd(); ///
void writeSetBegin(TSet set); ///
void writeSetEnd(); ///
/*
* Reading methods.
*/
bool readBool(); ///
byte readByte(); ///
short readI16(); ///
int readI32(); ///
long readI64(); ///
double readDouble(); ///
string readString(); ///
ubyte[] readBinary(); ///
TMessage readMessageBegin(); ///
void readMessageEnd(); ///
TStruct readStructBegin(); ///
void readStructEnd(); ///
TField readFieldBegin(); ///
void readFieldEnd(); ///
TList readListBegin(); ///
void readListEnd(); ///
TMap readMapBegin(); ///
void readMapEnd(); ///
TSet readSetBegin(); ///
void readSetEnd(); ///
/**
* Reset any internal state back to a blank slate, if the protocol is
* stateful.
*/
void reset();
}
/**
* true if T is a TProtocol.
*/
template isTProtocol(T) {
enum isTProtocol = is(T : TProtocol);
}
unittest {
static assert(isTProtocol!TProtocol);
static assert(!isTProtocol!void);
}
/**
* Creates a protocol operating on a given transport.
*/
interface TProtocolFactory {
///
TProtocol getProtocol(TTransport trans);
}
/**
* A protocol-level exception.
*/
class TProtocolException : TException {
/// The possible exception types.
enum Type {
UNKNOWN, ///
INVALID_DATA, ///
NEGATIVE_SIZE, ///
SIZE_LIMIT, ///
BAD_VERSION, ///
NOT_IMPLEMENTED, ///
DEPTH_LIMIT ///
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
static string msgForType(Type type) {
switch (type) {
case Type.UNKNOWN: return "Unknown protocol exception";
case Type.INVALID_DATA: return "Invalid data";
case Type.NEGATIVE_SIZE: return "Negative size";
case Type.SIZE_LIMIT: return "Exceeded size limit";
case Type.BAD_VERSION: return "Invalid version";
case Type.NOT_IMPLEMENTED: return "Not implemented";
case Type.DEPTH_LIMIT: return "Exceeded size limit";
default: return "(Invalid exception type)";
}
}
this(msgForType(type), type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() const @property {
return type_;
}
protected:
Type type_;
}
/**
* Skips a field of the given type on the protocol.
*
* The main purpose of skip() is to allow treating struct and container types,
* (where multiple primitive types have to be skipped) the same as scalar types
* in generated code.
*/
void skip(Protocol)(Protocol prot, TType type) if (is(Protocol : TProtocol)) {
final switch (type) {
case TType.BOOL:
prot.readBool();
break;
case TType.BYTE:
prot.readByte();
break;
case TType.I16:
prot.readI16();
break;
case TType.I32:
prot.readI32();
break;
case TType.I64:
prot.readI64();
break;
case TType.DOUBLE:
prot.readDouble();
break;
case TType.STRING:
prot.readBinary();
break;
case TType.STRUCT:
prot.readStructBegin();
while (true) {
auto f = prot.readFieldBegin();
if (f.type == TType.STOP) break;
skip(prot, f.type);
prot.readFieldEnd();
}
prot.readStructEnd();
break;
case TType.LIST:
auto l = prot.readListBegin();
foreach (i; 0 .. l.size) {
skip(prot, l.elemType);
}
prot.readListEnd();
break;
case TType.MAP:
auto m = prot.readMapBegin();
foreach (i; 0 .. m.size) {
skip(prot, m.keyType);
skip(prot, m.valueType);
}
prot.readMapEnd();
break;
case TType.SET:
auto s = prot.readSetBegin();
foreach (i; 0 .. s.size) {
skip(prot, s.elemType);
}
prot.readSetEnd();
break;
case TType.STOP: goto case;
case TType.VOID:
assert(false, "Invalid field type passed.");
}
}
/**
* Application-level exception.
*
* It is thrown if an RPC call went wrong on the application layer, e.g. if
* the receiver does not know the method name requested or a method invoked by
* the service processor throws an exception not part of the Thrift API.
*/
class TApplicationException : TException {
/// The possible exception types.
enum Type {
UNKNOWN = 0, ///
UNKNOWN_METHOD = 1, ///
INVALID_MESSAGE_TYPE = 2, ///
WRONG_METHOD_NAME = 3, ///
BAD_SEQUENCE_ID = 4, ///
MISSING_RESULT = 5, ///
INTERNAL_ERROR = 6, ///
PROTOCOL_ERROR = 7, ///
INVALID_TRANSFORM = 8, ///
INVALID_PROTOCOL = 9, ///
UNSUPPORTED_CLIENT_TYPE = 10 ///
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
static string msgForType(Type type) {
switch (type) {
case Type.UNKNOWN: return "Unknown application exception";
case Type.UNKNOWN_METHOD: return "Unknown method";
case Type.INVALID_MESSAGE_TYPE: return "Invalid message type";
case Type.WRONG_METHOD_NAME: return "Wrong method name";
case Type.BAD_SEQUENCE_ID: return "Bad sequence identifier";
case Type.MISSING_RESULT: return "Missing result";
case Type.INTERNAL_ERROR: return "Internal error";
case Type.PROTOCOL_ERROR: return "Protocol error";
case Type.INVALID_TRANSFORM: return "Invalid transform";
case Type.INVALID_PROTOCOL: return "Invalid protocol";
case Type.UNSUPPORTED_CLIENT_TYPE: return "Unsupported client type";
default: return "(Invalid exception type)";
}
}
this(msgForType(type), type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() @property const {
return type_;
}
// TODO: Replace hand-written read()/write() with thrift.codegen templates.
///
void read(TProtocol iprot) {
iprot.readStructBegin();
while (true) {
auto f = iprot.readFieldBegin();
if (f.type == TType.STOP) break;
switch (f.id) {
case 1:
if (f.type == TType.STRING) {
msg = iprot.readString();
} else {
skip(iprot, f.type);
}
break;
case 2:
if (f.type == TType.I32) {
type_ = cast(Type)iprot.readI32();
} else {
skip(iprot, f.type);
}
break;
default:
skip(iprot, f.type);
break;
}
}
iprot.readStructEnd();
}
///
void write(TProtocol oprot) const {
oprot.writeStructBegin(TStruct("TApplicationException"));
if (msg != null) {
oprot.writeFieldBegin(TField("message", TType.STRING, 1));
oprot.writeString(msg);
oprot.writeFieldEnd();
}
oprot.writeFieldBegin(TField("type", TType.I32, 2));
oprot.writeI32(type_);
oprot.writeFieldEnd();
oprot.writeFieldStop();
oprot.writeStructEnd();
}
private:
Type type_;
}

View file

@ -0,0 +1,414 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.protocol.binary;
import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;
/**
* TProtocol implementation of the Binary Thrift protocol.
*/
final class TBinaryProtocol(Transport = TTransport) if (
isTTransport!Transport
) : TProtocol {
/**
* Constructs a new instance.
*
* Params:
* trans = The transport to use.
* containerSizeLimit = If positive, the container size is limited to the
* given number of items.
* stringSizeLimit = If positive, the string length is limited to the
* given number of bytes.
* strictRead = If false, old peers which do not include the protocol
* version are tolerated.
* strictWrite = Whether to include the protocol version in the header.
*/
this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0,
bool strictRead = false, bool strictWrite = true
) {
trans_ = trans;
this.containerSizeLimit = containerSizeLimit;
this.stringSizeLimit = stringSizeLimit;
this.strictRead = strictRead;
this.strictWrite = strictWrite;
}
Transport transport() @property {
return trans_;
}
void reset() {}
/**
* If false, old peers which do not include the protocol version in the
* message header are tolerated.
*
* Defaults to false.
*/
bool strictRead;
/**
* Whether to include the protocol version in the message header (older
* versions didn't).
*
* Defaults to true.
*/
bool strictWrite;
/**
* If positive, limits the number of items of deserialized containers to the
* given amount.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int containerSizeLimit;
/**
* If positive, limits the length of deserialized strings/binary data to the
* given number of bytes.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int stringSizeLimit;
/*
* Writing methods.
*/
void writeBool(bool b) {
writeByte(b ? 1 : 0);
}
void writeByte(byte b) {
trans_.write((cast(ubyte*)&b)[0 .. 1]);
}
void writeI16(short i16) {
short net = hostToNet(i16);
trans_.write((cast(ubyte*)&net)[0 .. 2]);
}
void writeI32(int i32) {
int net = hostToNet(i32);
trans_.write((cast(ubyte*)&net)[0 .. 4]);
}
void writeI64(long i64) {
long net = hostToNet(i64);
trans_.write((cast(ubyte*)&net)[0 .. 8]);
}
void writeDouble(double dub) {
static assert(double.sizeof == ulong.sizeof);
auto bits = hostToNet(*cast(ulong*)(&dub));
trans_.write((cast(ubyte*)&bits)[0 .. 8]);
}
void writeString(string str) {
writeBinary(cast(ubyte[])str);
}
void writeBinary(ubyte[] buf) {
assert(buf.length <= int.max);
writeI32(cast(int)buf.length);
trans_.write(buf);
}
void writeMessageBegin(TMessage message) {
if (strictWrite) {
int versn = VERSION_1 | message.type;
writeI32(versn);
writeString(message.name);
writeI32(message.seqid);
} else {
writeString(message.name);
writeByte(message.type);
writeI32(message.seqid);
}
}
void writeMessageEnd() {}
void writeStructBegin(TStruct tstruct) {}
void writeStructEnd() {}
void writeFieldBegin(TField field) {
writeByte(field.type);
writeI16(field.id);
}
void writeFieldEnd() {}
void writeFieldStop() {
writeByte(TType.STOP);
}
void writeListBegin(TList list) {
assert(list.size <= int.max);
writeByte(list.elemType);
writeI32(cast(int)list.size);
}
void writeListEnd() {}
void writeMapBegin(TMap map) {
assert(map.size <= int.max);
writeByte(map.keyType);
writeByte(map.valueType);
writeI32(cast(int)map.size);
}
void writeMapEnd() {}
void writeSetBegin(TSet set) {
assert(set.size <= int.max);
writeByte(set.elemType);
writeI32(cast(int)set.size);
}
void writeSetEnd() {}
/*
* Reading methods.
*/
bool readBool() {
return readByte() != 0;
}
byte readByte() {
ubyte[1] b = void;
trans_.readAll(b);
return cast(byte)b[0];
}
short readI16() {
IntBuf!short b = void;
trans_.readAll(b.bytes);
return netToHost(b.value);
}
int readI32() {
IntBuf!int b = void;
trans_.readAll(b.bytes);
return netToHost(b.value);
}
long readI64() {
IntBuf!long b = void;
trans_.readAll(b.bytes);
return netToHost(b.value);
}
double readDouble() {
IntBuf!long b = void;
trans_.readAll(b.bytes);
b.value = netToHost(b.value);
return *cast(double*)(&b.value);
}
string readString() {
return cast(string)readBinary();
}
ubyte[] readBinary() {
return readBinaryBody(readSize(stringSizeLimit));
}
TMessage readMessageBegin() {
TMessage msg = void;
int size = readI32();
if (size < 0) {
int versn = size & VERSION_MASK;
if (versn != VERSION_1) {
throw new TProtocolException("Bad protocol version.",
TProtocolException.Type.BAD_VERSION);
}
msg.type = cast(TMessageType)(size & MESSAGE_TYPE_MASK);
msg.name = readString();
msg.seqid = readI32();
} else {
if (strictRead) {
throw new TProtocolException(
"Protocol version missing, old client?",
TProtocolException.Type.BAD_VERSION);
} else {
if (size < 0) {
throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
}
msg.name = cast(string)readBinaryBody(size);
msg.type = cast(TMessageType)(readByte());
msg.seqid = readI32();
}
}
return msg;
}
void readMessageEnd() {}
TStruct readStructBegin() {
return TStruct();
}
void readStructEnd() {}
TField readFieldBegin() {
TField f = void;
f.name = null;
f.type = cast(TType)readByte();
if (f.type == TType.STOP) return f;
f.id = readI16();
return f;
}
void readFieldEnd() {}
TList readListBegin() {
return TList(cast(TType)readByte(), readSize(containerSizeLimit));
}
void readListEnd() {}
TMap readMapBegin() {
return TMap(cast(TType)readByte(), cast(TType)readByte(),
readSize(containerSizeLimit));
}
void readMapEnd() {}
TSet readSetBegin() {
return TSet(cast(TType)readByte(), readSize(containerSizeLimit));
}
void readSetEnd() {}
private:
ubyte[] readBinaryBody(int size) {
if (size == 0) {
return null;
}
auto buf = uninitializedArray!(ubyte[])(size);
trans_.readAll(buf);
return buf;
}
int readSize(int limit) {
auto size = readI32();
if (size < 0) {
throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
} else if (limit > 0 && size > limit) {
throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
}
return size;
}
enum MESSAGE_TYPE_MASK = 0x000000ff;
enum VERSION_MASK = 0xffff0000;
enum VERSION_1 = 0x80010000;
Transport trans_;
}
/**
* TBinaryProtocol construction helper to avoid having to explicitly specify
* the transport type, i.e. to allow the constructor being called using IFTI
* (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
* enhancement requet 6082)).
*/
TBinaryProtocol!Transport tBinaryProtocol(Transport)(Transport trans,
int containerSizeLimit = 0, int stringSizeLimit = 0,
bool strictRead = false, bool strictWrite = true
) if (isTTransport!Transport) {
return new TBinaryProtocol!Transport(trans, containerSizeLimit,
stringSizeLimit, strictRead, strictWrite);
}
unittest {
import std.exception;
import thrift.transport.memory;
// Check the message header format.
auto buf = new TMemoryBuffer;
auto binary = tBinaryProtocol(buf);
binary.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));
auto header = new ubyte[15];
buf.readAll(header);
enforce(header == [
128, 1, 0, 1, // Version 1, TMessageType.CALL
0, 0, 0, 3, // Method name length
102, 111, 111, // Method name ("foo")
0, 0, 0, 0, // Sequence id
]);
}
unittest {
import thrift.internal.test.protocol;
testContainerSizeLimit!(TBinaryProtocol!())();
testStringSizeLimit!(TBinaryProtocol!())();
}
/**
* TProtocolFactory creating a TBinaryProtocol instance for passed in
* transports.
*
* The optional Transports template tuple parameter can be used to specify
* one or more TTransport implementations to specifically instantiate
* TBinaryProtocol for. If the actual transport types encountered at
* runtime match one of the transports in the list, a specialized protocol
* instance is created. Otherwise, a generic TTransport version is used.
*/
class TBinaryProtocolFactory(Transports...) if (
allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
///
this (int containerSizeLimit = 0, int stringSizeLimit = 0,
bool strictRead = false, bool strictWrite = true
) {
strictRead_ = strictRead;
strictWrite_ = strictWrite;
containerSizeLimit_ = containerSizeLimit;
stringSizeLimit_ = stringSizeLimit;
}
TProtocol getProtocol(TTransport trans) const {
foreach (Transport; TypeTuple!(Transports, TTransport)) {
auto concreteTrans = cast(Transport)trans;
if (concreteTrans) {
return new TBinaryProtocol!Transport(concreteTrans,
containerSizeLimit_, stringSizeLimit_, strictRead_, strictWrite_);
}
}
throw new TProtocolException(
"Passed null transport to TBinaryProtocolFactoy.");
}
protected:
bool strictRead_;
bool strictWrite_;
int containerSizeLimit_;
int stringSizeLimit_;
}

View file

@ -0,0 +1,698 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.protocol.compact;
import std.array : uninitializedArray;
import std.typetuple : allSatisfy, TypeTuple;
import thrift.protocol.base;
import thrift.transport.base;
import thrift.internal.endian;
/**
* D implementation of the Compact protocol.
*
* See THRIFT-110 for a protocol description. This implementation is based on
* the C++ one.
*/
final class TCompactProtocol(Transport = TTransport) if (
isTTransport!Transport
) : TProtocol {
/**
* Constructs a new instance.
*
* Params:
* trans = The transport to use.
* containerSizeLimit = If positive, the container size is limited to the
* given number of items.
* stringSizeLimit = If positive, the string length is limited to the
* given number of bytes.
*/
this(Transport trans, int containerSizeLimit = 0, int stringSizeLimit = 0) {
trans_ = trans;
this.containerSizeLimit = containerSizeLimit;
this.stringSizeLimit = stringSizeLimit;
}
Transport transport() @property {
return trans_;
}
void reset() {
lastFieldId_ = 0;
fieldIdStack_ = null;
booleanField_ = TField.init;
hasBoolValue_ = false;
}
/**
* If positive, limits the number of items of deserialized containers to the
* given amount.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int containerSizeLimit;
/**
* If positive, limits the length of deserialized strings/binary data to the
* given number of bytes.
*
* This is useful to avoid allocating excessive amounts of memory when broken
* data is received. If the limit is exceeded, a SIZE_LIMIT-type
* TProtocolException is thrown.
*
* Defaults to zero (no limit).
*/
int stringSizeLimit;
/*
* Writing methods.
*/
void writeBool(bool b) {
if (booleanField_.name !is null) {
// we haven't written the field header yet
writeFieldBeginInternal(booleanField_,
b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
booleanField_.name = null;
} else {
// we're not part of a field, so just write the value
writeByte(b ? CType.BOOLEAN_TRUE : CType.BOOLEAN_FALSE);
}
}
void writeByte(byte b) {
trans_.write((cast(ubyte*)&b)[0..1]);
}
void writeI16(short i16) {
writeVarint32(i32ToZigzag(i16));
}
void writeI32(int i32) {
writeVarint32(i32ToZigzag(i32));
}
void writeI64(long i64) {
writeVarint64(i64ToZigzag(i64));
}
void writeDouble(double dub) {
ulong bits = hostToLe(*cast(ulong*)(&dub));
trans_.write((cast(ubyte*)&bits)[0 .. 8]);
}
void writeString(string str) {
writeBinary(cast(ubyte[])str);
}
void writeBinary(ubyte[] buf) {
assert(buf.length <= int.max);
writeVarint32(cast(int)buf.length);
trans_.write(buf);
}
void writeMessageBegin(TMessage msg) {
writeByte(cast(byte)PROTOCOL_ID);
writeByte(cast(byte)((VERSION_N & VERSION_MASK) |
((cast(int)msg.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)));
writeVarint32(msg.seqid);
writeString(msg.name);
}
void writeMessageEnd() {}
void writeStructBegin(TStruct tstruct) {
fieldIdStack_ ~= lastFieldId_;
lastFieldId_ = 0;
}
void writeStructEnd() {
lastFieldId_ = fieldIdStack_[$ - 1];
fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
fieldIdStack_.assumeSafeAppend();
}
void writeFieldBegin(TField field) {
if (field.type == TType.BOOL) {
booleanField_.name = field.name;
booleanField_.type = field.type;
booleanField_.id = field.id;
} else {
return writeFieldBeginInternal(field);
}
}
void writeFieldEnd() {}
void writeFieldStop() {
writeByte(TType.STOP);
}
void writeListBegin(TList list) {
writeCollectionBegin(list.elemType, list.size);
}
void writeListEnd() {}
void writeMapBegin(TMap map) {
if (map.size == 0) {
writeByte(0);
} else {
assert(map.size <= int.max);
writeVarint32(cast(int)map.size);
writeByte(cast(byte)(toCType(map.keyType) << 4 | toCType(map.valueType)));
}
}
void writeMapEnd() {}
void writeSetBegin(TSet set) {
writeCollectionBegin(set.elemType, set.size);
}
void writeSetEnd() {}
/*
* Reading methods.
*/
bool readBool() {
if (hasBoolValue_ == true) {
hasBoolValue_ = false;
return boolValue_;
}
return readByte() == CType.BOOLEAN_TRUE;
}
byte readByte() {
ubyte[1] b = void;
trans_.readAll(b);
return cast(byte)b[0];
}
short readI16() {
return cast(short)zigzagToI32(readVarint32());
}
int readI32() {
return zigzagToI32(readVarint32());
}
long readI64() {
return zigzagToI64(readVarint64());
}
double readDouble() {
IntBuf!long b = void;
trans_.readAll(b.bytes);
b.value = leToHost(b.value);
return *cast(double*)(&b.value);
}
string readString() {
return cast(string)readBinary();
}
ubyte[] readBinary() {
auto size = readVarint32();
checkSize(size, stringSizeLimit);
if (size == 0) {
return null;
}
auto buf = uninitializedArray!(ubyte[])(size);
trans_.readAll(buf);
return buf;
}
TMessage readMessageBegin() {
TMessage msg = void;
auto protocolId = readByte();
if (protocolId != cast(byte)PROTOCOL_ID) {
throw new TProtocolException("Bad protocol identifier",
TProtocolException.Type.BAD_VERSION);
}
auto versionAndType = readByte();
auto ver = versionAndType & VERSION_MASK;
if (ver != VERSION_N) {
throw new TProtocolException("Bad protocol version",
TProtocolException.Type.BAD_VERSION);
}
msg.type = cast(TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & TYPE_BITS);
msg.seqid = readVarint32();
msg.name = readString();
return msg;
}
void readMessageEnd() {}
TStruct readStructBegin() {
fieldIdStack_ ~= lastFieldId_;
lastFieldId_ = 0;
return TStruct();
}
void readStructEnd() {
lastFieldId_ = fieldIdStack_[$ - 1];
fieldIdStack_ = fieldIdStack_[0 .. $ - 1];
}
TField readFieldBegin() {
TField f = void;
f.name = null;
auto bite = readByte();
auto type = cast(CType)(bite & 0x0f);
if (type == CType.STOP) {
// Struct stop byte, nothing more to do.
f.id = 0;
f.type = TType.STOP;
return f;
}
// Mask off the 4 MSB of the type header, which could contain a field id
// delta.
auto modifier = cast(short)((bite & 0xf0) >> 4);
if (modifier > 0) {
f.id = cast(short)(lastFieldId_ + modifier);
} else {
// Delta encoding not used, just read the id as usual.
f.id = readI16();
}
f.type = getTType(type);
if (type == CType.BOOLEAN_TRUE || type == CType.BOOLEAN_FALSE) {
// For boolean fields, the value is encoded in the type keep it around
// for the readBool() call.
hasBoolValue_ = true;
boolValue_ = (type == CType.BOOLEAN_TRUE ? true : false);
}
lastFieldId_ = f.id;
return f;
}
void readFieldEnd() {}
TList readListBegin() {
auto sizeAndType = readByte();
auto lsize = (sizeAndType >> 4) & 0xf;
if (lsize == 0xf) {
lsize = readVarint32();
}
checkSize(lsize, containerSizeLimit);
TList l = void;
l.elemType = getTType(cast(CType)(sizeAndType & 0x0f));
l.size = cast(size_t)lsize;
return l;
}
void readListEnd() {}
TMap readMapBegin() {
TMap m = void;
auto size = readVarint32();
ubyte kvType;
if (size != 0) {
kvType = readByte();
}
checkSize(size, containerSizeLimit);
m.size = size;
m.keyType = getTType(cast(CType)(kvType >> 4));
m.valueType = getTType(cast(CType)(kvType & 0xf));
return m;
}
void readMapEnd() {}
TSet readSetBegin() {
auto sizeAndType = readByte();
auto lsize = (sizeAndType >> 4) & 0xf;
if (lsize == 0xf) {
lsize = readVarint32();
}
checkSize(lsize, containerSizeLimit);
TSet s = void;
s.elemType = getTType(cast(CType)(sizeAndType & 0xf));
s.size = cast(size_t)lsize;
return s;
}
void readSetEnd() {}
private:
void writeFieldBeginInternal(TField field, byte typeOverride = -1) {
// If there's a type override, use that.
auto typeToWrite = (typeOverride == -1 ? toCType(field.type) : typeOverride);
// check if we can use delta encoding for the field id
if (field.id > lastFieldId_ && (field.id - lastFieldId_) <= 15) {
// write them together
writeByte(cast(byte)((field.id - lastFieldId_) << 4 | typeToWrite));
} else {
// write them separate
writeByte(cast(byte)typeToWrite);
writeI16(field.id);
}
lastFieldId_ = field.id;
}
void writeCollectionBegin(TType elemType, size_t size) {
if (size <= 14) {
writeByte(cast(byte)(size << 4 | toCType(elemType)));
} else {
assert(size <= int.max);
writeByte(cast(byte)(0xf0 | toCType(elemType)));
writeVarint32(cast(int)size);
}
}
void writeVarint32(uint n) {
ubyte[5] buf = void;
ubyte wsize;
while (true) {
if ((n & ~0x7F) == 0) {
buf[wsize++] = cast(ubyte)n;
break;
} else {
buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
n >>= 7;
}
}
trans_.write(buf[0 .. wsize]);
}
/*
* Write an i64 as a varint. Results in 1-10 bytes on the wire.
*/
void writeVarint64(ulong n) {
ubyte[10] buf = void;
ubyte wsize;
while (true) {
if ((n & ~0x7FL) == 0) {
buf[wsize++] = cast(ubyte)n;
break;
} else {
buf[wsize++] = cast(ubyte)((n & 0x7F) | 0x80);
n >>= 7;
}
}
trans_.write(buf[0 .. wsize]);
}
/*
* Convert l into a zigzag long. This allows negative numbers to be
* represented compactly as a varint.
*/
ulong i64ToZigzag(long l) {
return (l << 1) ^ (l >> 63);
}
/*
* Convert n into a zigzag int. This allows negative numbers to be
* represented compactly as a varint.
*/
uint i32ToZigzag(int n) {
return (n << 1) ^ (n >> 31);
}
CType toCType(TType type) {
final switch (type) {
case TType.STOP:
return CType.STOP;
case TType.BOOL:
return CType.BOOLEAN_TRUE;
case TType.BYTE:
return CType.BYTE;
case TType.DOUBLE:
return CType.DOUBLE;
case TType.I16:
return CType.I16;
case TType.I32:
return CType.I32;
case TType.I64:
return CType.I64;
case TType.STRING:
return CType.BINARY;
case TType.STRUCT:
return CType.STRUCT;
case TType.MAP:
return CType.MAP;
case TType.SET:
return CType.SET;
case TType.LIST:
return CType.LIST;
case TType.VOID:
assert(false, "Invalid type passed.");
}
}
int readVarint32() {
return cast(int)readVarint64();
}
long readVarint64() {
ulong val;
ubyte shift;
ubyte[10] buf = void; // 64 bits / (7 bits/byte) = 10 bytes.
auto bufSize = buf.sizeof;
auto borrowed = trans_.borrow(buf.ptr, bufSize);
ubyte rsize;
if (borrowed) {
// Fast path.
while (true) {
auto bite = borrowed[rsize];
rsize++;
val |= cast(ulong)(bite & 0x7f) << shift;
shift += 7;
if (!(bite & 0x80)) {
trans_.consume(rsize);
return val;
}
// Have to check for invalid data so we don't crash.
if (rsize == buf.sizeof) {
throw new TProtocolException(TProtocolException.Type.INVALID_DATA,
"Variable-length int over 10 bytes.");
}
}
} else {
// Slow path.
while (true) {
ubyte[1] bite;
trans_.readAll(bite);
++rsize;
val |= cast(ulong)(bite[0] & 0x7f) << shift;
shift += 7;
if (!(bite[0] & 0x80)) {
return val;
}
// Might as well check for invalid data on the slow path too.
if (rsize >= buf.sizeof) {
throw new TProtocolException(TProtocolException.Type.INVALID_DATA,
"Variable-length int over 10 bytes.");
}
}
}
}
/*
* Convert from zigzag int to int.
*/
int zigzagToI32(uint n) {
return (n >> 1) ^ -(n & 1);
}
/*
* Convert from zigzag long to long.
*/
long zigzagToI64(ulong n) {
return (n >> 1) ^ -(n & 1);
}
TType getTType(CType type) {
final switch (type) {
case CType.STOP:
return TType.STOP;
case CType.BOOLEAN_FALSE:
return TType.BOOL;
case CType.BOOLEAN_TRUE:
return TType.BOOL;
case CType.BYTE:
return TType.BYTE;
case CType.I16:
return TType.I16;
case CType.I32:
return TType.I32;
case CType.I64:
return TType.I64;
case CType.DOUBLE:
return TType.DOUBLE;
case CType.BINARY:
return TType.STRING;
case CType.LIST:
return TType.LIST;
case CType.SET:
return TType.SET;
case CType.MAP:
return TType.MAP;
case CType.STRUCT:
return TType.STRUCT;
}
}
void checkSize(int size, int limit) {
if (size < 0) {
throw new TProtocolException(TProtocolException.Type.NEGATIVE_SIZE);
} else if (limit > 0 && size > limit) {
throw new TProtocolException(TProtocolException.Type.SIZE_LIMIT);
}
}
enum PROTOCOL_ID = 0x82;
enum VERSION_N = 1;
enum VERSION_MASK = 0b0001_1111;
enum TYPE_MASK = 0b1110_0000;
enum TYPE_BITS = 0b0000_0111;
enum TYPE_SHIFT_AMOUNT = 5;
// Probably need to implement a better stack at some point.
short[] fieldIdStack_;
short lastFieldId_;
TField booleanField_;
bool hasBoolValue_;
bool boolValue_;
Transport trans_;
}
/**
* TCompactProtocol construction helper to avoid having to explicitly specify
* the transport type, i.e. to allow the constructor being called using IFTI
* (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D Bugzilla
* enhancement requet 6082)).
*/
TCompactProtocol!Transport tCompactProtocol(Transport)(Transport trans,
int containerSizeLimit = 0, int stringSizeLimit = 0
) if (isTTransport!Transport)
{
return new TCompactProtocol!Transport(trans,
containerSizeLimit, stringSizeLimit);
}
private {
enum CType : ubyte {
STOP = 0x0,
BOOLEAN_TRUE = 0x1,
BOOLEAN_FALSE = 0x2,
BYTE = 0x3,
I16 = 0x4,
I32 = 0x5,
I64 = 0x6,
DOUBLE = 0x7,
BINARY = 0x8,
LIST = 0x9,
SET = 0xa,
MAP = 0xb,
STRUCT = 0xc
}
static assert(CType.max <= 0xf,
"Compact protocol wire type representation must fit into 4 bits.");
}
unittest {
import std.exception;
import thrift.transport.memory;
// Check the message header format.
auto buf = new TMemoryBuffer;
auto compact = tCompactProtocol(buf);
compact.writeMessageBegin(TMessage("foo", TMessageType.CALL, 0));
auto header = new ubyte[7];
buf.readAll(header);
enforce(header == [
130, // Protocol id.
33, // Version/type byte.
0, // Sequence id.
3, 102, 111, 111 // Method name.
]);
}
unittest {
import thrift.internal.test.protocol;
testContainerSizeLimit!(TCompactProtocol!())();
testStringSizeLimit!(TCompactProtocol!())();
}
/**
* TProtocolFactory creating a TCompactProtocol instance for passed in
* transports.
*
* The optional Transports template tuple parameter can be used to specify
* one or more TTransport implementations to specifically instantiate
* TCompactProtocol for. If the actual transport types encountered at
* runtime match one of the transports in the list, a specialized protocol
* instance is created. Otherwise, a generic TTransport version is used.
*/
class TCompactProtocolFactory(Transports...) if (
allSatisfy!(isTTransport, Transports)
) : TProtocolFactory {
///
this(int containerSizeLimit = 0, int stringSizeLimit = 0) {
containerSizeLimit_ = 0;
stringSizeLimit_ = 0;
}
TProtocol getProtocol(TTransport trans) const {
foreach (Transport; TypeTuple!(Transports, TTransport)) {
auto concreteTrans = cast(Transport)trans;
if (concreteTrans) {
return new TCompactProtocol!Transport(concreteTrans);
}
}
throw new TProtocolException(
"Passed null transport to TCompactProtocolFactory.");
}
int containerSizeLimit_;
int stringSizeLimit_;
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.protocol.processor;
// Use selective import once DMD @@BUG314@@ is fixed.
import std.variant /+ : Variant +/;
import thrift.protocol.base;
import thrift.transport.base;
/**
* A processor is a generic object which operates upon an input stream and
* writes to some output stream.
*
* The definition of this object is loose, though the typical case is for some
* sort of server that either generates responses to an input stream or
* forwards data from one pipe onto another.
*
* An implementation can optionally allow one or more TProcessorEventHandlers
* to be attached, providing an interface to hook custom code into the
* handling process, which can be used e.g. for gathering statistics.
*/
interface TProcessor {
///
bool process(TProtocol iprot, TProtocol oprot,
Variant connectionContext = Variant()
) in {
assert(iprot);
assert(oprot);
}
///
final bool process(TProtocol prot, Variant connectionContext = Variant()) {
return process(prot, prot, connectionContext);
}
}
/**
* Handles events from a processor.
*/
interface TProcessorEventHandler {
/**
* Called before calling other callback methods.
*
* Expected to return some sort of »call context«, which is passed to all
* other callbacks for that function invocation.
*/
Variant createContext(string methodName, Variant connectionContext);
/**
* Called when handling the method associated with a context has been
* finished can be used to perform clean up work.
*/
void deleteContext(Variant callContext, string methodName);
/**
* Called before reading arguments.
*/
void preRead(Variant callContext, string methodName);
/**
* Called between reading arguments and calling the handler.
*/
void postRead(Variant callContext, string methodName);
/**
* Called between calling the handler and writing the response.
*/
void preWrite(Variant callContext, string methodName);
/**
* Called after writing the response.
*/
void postWrite(Variant callContext, string methodName);
/**
* Called when handling a one-way function call is completed successfully.
*/
void onewayComplete(Variant callContext, string methodName);
/**
* Called if the handler throws an undeclared exception.
*/
void handlerError(Variant callContext, string methodName, Exception e);
}
struct TConnectionInfo {
/// The input and output protocols.
TProtocol input;
TProtocol output; /// Ditto.
/// The underlying transport used for the connection
/// This is the transport that was returned by TServerTransport.accept(),
/// and it may be different than the transport pointed to by the input and
/// output protocols.
TTransport transport;
}
interface TProcessorFactory {
/**
* Get the TProcessor to use for a particular connection.
*
* This method is always invoked in the same thread that the connection was
* accepted on, which is always the same thread for all current server
* implementations.
*/
TProcessor getProcessor(ref const(TConnectionInfo) connInfo);
}
/**
* The default processor factory which always returns the same instance.
*/
class TSingletonProcessorFactory : TProcessorFactory {
/**
* Creates a new instance.
*
* Params:
* processor = The processor object to return from getProcessor().
*/
this(TProcessor processor) {
processor_ = processor;
}
override TProcessor getProcessor(ref const(TConnectionInfo) connInfo) {
return processor_;
}
private:
TProcessor processor_;
}

View file

@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.base;
import std.variant : Variant;
import thrift.protocol.base;
import thrift.protocol.binary;
import thrift.protocol.processor;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* Base class for all Thrift servers.
*
* By setting the eventHandler property to a TServerEventHandler
* implementation, custom code can be integrated into the processing pipeline,
* which can be used e.g. for gathering statistics.
*/
class TServer {
/**
* Starts serving.
*
* Blocks until the server finishes, i.e. a serious problem occurred or the
* cancellation request has been triggered.
*
* Server implementations are expected to implement cancellation in a best-
* effort way usually, it should be possible to immediately stop accepting
* connections and return after all currently active clients have been
* processed, but this might not be the case for every conceivable
* implementation.
*/
abstract void serve(TCancellation cancellation = null);
/// The server event handler to notify. Null by default.
TServerEventHandler eventHandler;
protected:
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
this(processor, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory);
}
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
this(processorFactory, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory);
}
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
this(new TSingletonProcessorFactory(processor), serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory);
}
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
import std.exception;
import thrift.base;
enforce(inputTransportFactory,
new TException("Input transport factory must not be null."));
enforce(outputTransportFactory,
new TException("Output transport factory must not be null."));
enforce(inputProtocolFactory,
new TException("Input protocol factory must not be null."));
enforce(outputProtocolFactory,
new TException("Output protocol factory must not be null."));
processorFactory_ = processorFactory;
serverTransport_ = serverTransport;
inputTransportFactory_ = inputTransportFactory;
outputTransportFactory_ = outputTransportFactory;
inputProtocolFactory_ = inputProtocolFactory;
outputProtocolFactory_ = outputProtocolFactory;
}
TProcessorFactory processorFactory_;
TServerTransport serverTransport_;
TTransportFactory inputTransportFactory_;
TTransportFactory outputTransportFactory_;
TProtocolFactory inputProtocolFactory_;
TProtocolFactory outputProtocolFactory_;
}
/**
* Handles events from a TServer core.
*/
interface TServerEventHandler {
/**
* Called before the server starts accepting connections.
*/
void preServe();
/**
* Called when a new client has connected and processing is about to begin.
*/
Variant createContext(TProtocol input, TProtocol output);
/**
* Called when request handling for a client has been finished can be used
* to perform clean up work.
*/
void deleteContext(Variant serverContext, TProtocol input, TProtocol output);
/**
* Called when the processor for a client call is about to be invoked.
*/
void preProcess(Variant serverContext, TTransport transport);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.simple;
import std.variant : Variant;
import thrift.base;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* The most basic server.
*
* It is single-threaded and after it accepts a connections, it processes
* requests on it until it closes, then waiting for the next connection.
*
* It is not so much of use in production than it is for writing unittests, or
* as an example on how to provide a custom TServer implementation.
*/
class TSimpleServer : TServer {
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processor, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processorFactory, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processor, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processorFactory, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
override void serve(TCancellation cancellation = null) {
serverTransport_.listen();
if (eventHandler) eventHandler.preServe();
while (true) {
TTransport client;
TTransport inputTransport;
TTransport outputTransport;
TProtocol inputProtocol;
TProtocol outputProtocol;
try {
client = serverTransport_.accept(cancellation);
scope(failure) client.close();
inputTransport = inputTransportFactory_.getTransport(client);
scope(failure) inputTransport.close();
outputTransport = outputTransportFactory_.getTransport(client);
scope(failure) outputTransport.close();
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
} catch (TCancelledException tcx) {
break;
} catch (TTransportException ttx) {
logError("TServerTransport failed on accept: %s", ttx);
continue;
} catch (TException tx) {
logError("Caught TException on accept: %s", tx);
continue;
}
auto info = TConnectionInfo(inputProtocol, outputProtocol, client);
auto processor = processorFactory_.getProcessor(info);
Variant connectionContext;
if (eventHandler) {
connectionContext =
eventHandler.createContext(inputProtocol, outputProtocol);
}
try {
while (true) {
if (eventHandler) {
eventHandler.preProcess(connectionContext, client);
}
if (!processor.process(inputProtocol, outputProtocol,
connectionContext) || !inputProtocol.transport.peek()
) {
// Something went fundamentlly wrong or there is nothing more to
// process, close the connection.
break;
}
}
} catch (TTransportException ttx) {
logError("Client died: %s", ttx);
} catch (Exception e) {
logError("Uncaught exception: %s", e);
}
if (eventHandler) {
eventHandler.deleteContext(connectionContext, inputProtocol,
outputProtocol);
}
try {
inputTransport.close();
} catch (TTransportException ttx) {
logError("Input close failed: %s", ttx);
}
try {
outputTransport.close();
} catch (TTransportException ttx) {
logError("Output close failed: %s", ttx);
}
try {
client.close();
} catch (TTransportException ttx) {
logError("Client close failed: %s", ttx);
}
}
try {
serverTransport_.close();
} catch (TServerTransportException e) {
logError("Server transport failed to close(): %s", e);
}
}
}
unittest {
import thrift.internal.test.server;
testServeCancel!TSimpleServer();
}

View file

@ -0,0 +1,302 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.taskpool;
import core.sync.condition;
import core.sync.mutex;
import std.exception : enforce;
import std.parallelism;
import std.variant : Variant;
import thrift.base;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* A server which dispatches client requests to a std.parallelism TaskPool.
*/
class TTaskPoolServer : TServer {
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory,
TaskPool taskPool = null
) {
this(processor, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory, taskPool);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory,
TaskPool taskPool = null
) {
this(processorFactory, serverTransport, transportFactory, transportFactory,
protocolFactory, protocolFactory, taskPool);
}
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory,
TaskPool taskPool = null
) {
this(new TSingletonProcessorFactory(processor), serverTransport,
inputTransportFactory, outputTransportFactory,
inputProtocolFactory, outputProtocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory,
TaskPool taskPool = null
) {
super(processorFactory, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
if (taskPool) {
this.taskPool = taskPool;
} else {
auto ptp = std.parallelism.taskPool;
if (ptp.size > 0) {
taskPool_ = ptp;
} else {
// If the global task pool is empty (default on a single-core machine),
// create a new one with a single worker thread. The rationale for this
// is to avoid that an application which worked fine with no task pool
// explicitly set on the multi-core developer boxes suddenly fails on a
// single-core user machine.
taskPool_ = new TaskPool(1);
taskPool_.isDaemon = true;
}
}
}
override void serve(TCancellation cancellation = null) {
serverTransport_.listen();
if (eventHandler) eventHandler.preServe();
auto queueState = QueueState();
while (true) {
// Check if we can still handle more connections.
if (maxActiveConns) {
synchronized (queueState.mutex) {
while (queueState.activeConns >= maxActiveConns) {
queueState.connClosed.wait();
}
}
}
TTransport client;
TTransport inputTransport;
TTransport outputTransport;
TProtocol inputProtocol;
TProtocol outputProtocol;
try {
client = serverTransport_.accept(cancellation);
scope(failure) client.close();
inputTransport = inputTransportFactory_.getTransport(client);
scope(failure) inputTransport.close();
outputTransport = outputTransportFactory_.getTransport(client);
scope(failure) outputTransport.close();
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
} catch (TCancelledException tce) {
break;
} catch (TTransportException ttx) {
logError("TServerTransport failed on accept: %s", ttx);
continue;
} catch (TException tx) {
logError("Caught TException on accept: %s", tx);
continue;
}
auto info = TConnectionInfo(inputProtocol, outputProtocol, client);
auto processor = processorFactory_.getProcessor(info);
synchronized (queueState.mutex) {
++queueState.activeConns;
}
taskPool_.put(task!worker(queueState, client, inputProtocol,
outputProtocol, processor, eventHandler));
}
// First, stop accepting new connections.
try {
serverTransport_.close();
} catch (TServerTransportException e) {
logError("Server transport failed to close: %s", e);
}
// Then, wait until all active connections are finished.
synchronized (queueState.mutex) {
while (queueState.activeConns > 0) {
queueState.connClosed.wait();
}
}
}
/**
* Sets the task pool to use.
*
* By default, the global std.parallelism taskPool instance is used, which
* might not be appropriate for many applications, e.g. where tuning the
* number of worker threads is desired. (On single-core systems, a private
* task pool with a single thread is used by default, since the global
* taskPool instance has no worker threads then.)
*
* Note: TTaskPoolServer expects that tasks are never dropped from the pool,
* e.g. by calling TaskPool.close() while there are still tasks in the
* queue. If this happens, serve() will never return.
*/
void taskPool(TaskPool pool) @property {
enforce(pool !is null, "Cannot use a null task pool.");
enforce(pool.size > 0, "Cannot use a task pool with no worker threads.");
taskPool_ = pool;
}
/**
* The maximum number of client connections open at the same time. Zero for
* no limit, which is the default.
*
* If this limit is reached, no clients are accept()ed from the server
* transport any longer until another connection has been closed again.
*/
size_t maxActiveConns;
protected:
TaskPool taskPool_;
}
// Cannot be private as worker has to be passed as alias parameter to
// another module.
// private {
/*
* The state of the »connection queue«, i.e. used for keeping track of how
* many client connections are currently processed.
*/
struct QueueState {
/// Protects the queue state.
Mutex mutex;
/// The number of active connections (from the time they are accept()ed
/// until they are closed when the worked task finishes).
size_t activeConns;
/// Signals that the number of active connections has been decreased, i.e.
/// that a connection has been closed.
Condition connClosed;
/// Returns an initialized instance.
static QueueState opCall() {
QueueState q;
q.mutex = new Mutex;
q.connClosed = new Condition(q.mutex);
return q;
}
}
void worker(ref QueueState queueState, TTransport client,
TProtocol inputProtocol, TProtocol outputProtocol,
TProcessor processor, TServerEventHandler eventHandler)
{
scope (exit) {
synchronized (queueState.mutex) {
assert(queueState.activeConns > 0);
--queueState.activeConns;
queueState.connClosed.notifyAll();
}
}
Variant connectionContext;
if (eventHandler) {
connectionContext =
eventHandler.createContext(inputProtocol, outputProtocol);
}
try {
while (true) {
if (eventHandler) {
eventHandler.preProcess(connectionContext, client);
}
if (!processor.process(inputProtocol, outputProtocol,
connectionContext) || !inputProtocol.transport.peek()
) {
// Something went fundamentlly wrong or there is nothing more to
// process, close the connection.
break;
}
}
} catch (TTransportException ttx) {
logError("Client died: %s", ttx);
} catch (Exception e) {
logError("Uncaught exception: %s", e);
}
if (eventHandler) {
eventHandler.deleteContext(connectionContext, inputProtocol,
outputProtocol);
}
try {
inputProtocol.transport.close();
} catch (TTransportException ttx) {
logError("Input close failed: %s", ttx);
}
try {
outputProtocol.transport.close();
} catch (TTransportException ttx) {
logError("Output close failed: %s", ttx);
}
try {
client.close();
} catch (TTransportException ttx) {
logError("Client close failed: %s", ttx);
}
}
// }
unittest {
import thrift.internal.test.server;
testServeCancel!TTaskPoolServer();
}

View file

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.threaded;
import core.thread;
import std.variant : Variant;
import thrift.base;
import thrift.protocol.base;
import thrift.protocol.processor;
import thrift.server.base;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* A simple threaded server which spawns a new thread per connection.
*/
class TThreadedServer : TServer {
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processor, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory transportFactory,
TProtocolFactory protocolFactory
) {
super(processorFactory, serverTransport, transportFactory, protocolFactory);
}
///
this(
TProcessor processor,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processor, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
///
this(
TProcessorFactory processorFactory,
TServerTransport serverTransport,
TTransportFactory inputTransportFactory,
TTransportFactory outputTransportFactory,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory
) {
super(processorFactory, serverTransport, inputTransportFactory,
outputTransportFactory, inputProtocolFactory, outputProtocolFactory);
}
override void serve(TCancellation cancellation = null) {
try {
// Start the server listening
serverTransport_.listen();
} catch (TTransportException ttx) {
logError("listen() failed: %s", ttx);
return;
}
if (eventHandler) eventHandler.preServe();
auto workerThreads = new ThreadGroup();
while (true) {
TTransport client;
TTransport inputTransport;
TTransport outputTransport;
TProtocol inputProtocol;
TProtocol outputProtocol;
try {
client = serverTransport_.accept(cancellation);
scope(failure) client.close();
inputTransport = inputTransportFactory_.getTransport(client);
scope(failure) inputTransport.close();
outputTransport = outputTransportFactory_.getTransport(client);
scope(failure) outputTransport.close();
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
} catch (TCancelledException tce) {
break;
} catch (TTransportException ttx) {
logError("TServerTransport failed on accept: %s", ttx);
continue;
} catch (TException tx) {
logError("Caught TException on accept: %s", tx);
continue;
}
auto info = TConnectionInfo(inputProtocol, outputProtocol, client);
auto processor = processorFactory_.getProcessor(info);
auto worker = new WorkerThread(client, inputProtocol, outputProtocol,
processor, eventHandler);
workerThreads.add(worker);
worker.start();
}
try {
serverTransport_.close();
} catch (TServerTransportException e) {
logError("Server transport failed to close: %s", e);
}
workerThreads.joinAll();
}
}
// The worker thread handling a client connection.
private class WorkerThread : Thread {
this(TTransport client, TProtocol inputProtocol, TProtocol outputProtocol,
TProcessor processor, TServerEventHandler eventHandler)
{
client_ = client;
inputProtocol_ = inputProtocol;
outputProtocol_ = outputProtocol;
processor_ = processor;
eventHandler_ = eventHandler;
super(&run);
}
void run() {
Variant connectionContext;
if (eventHandler_) {
connectionContext =
eventHandler_.createContext(inputProtocol_, outputProtocol_);
}
try {
while (true) {
if (eventHandler_) {
eventHandler_.preProcess(connectionContext, client_);
}
if (!processor_.process(inputProtocol_, outputProtocol_,
connectionContext) || !inputProtocol_.transport.peek()
) {
// Something went fundamentlly wrong or there is nothing more to
// process, close the connection.
break;
}
}
} catch (TTransportException ttx) {
logError("Client died: %s", ttx);
} catch (Exception e) {
logError("Uncaught exception: %s", e);
}
if (eventHandler_) {
eventHandler_.deleteContext(connectionContext, inputProtocol_,
outputProtocol_);
}
try {
inputProtocol_.transport.close();
} catch (TTransportException ttx) {
logError("Input close failed: %s", ttx);
}
try {
outputProtocol_.transport.close();
} catch (TTransportException ttx) {
logError("Output close failed: %s", ttx);
}
try {
client_.close();
} catch (TTransportException ttx) {
logError("Client close failed: %s", ttx);
}
}
private:
TTransport client_;
TProtocol inputProtocol_;
TProtocol outputProtocol_;
TProcessor processor_;
TServerEventHandler eventHandler_;
}
unittest {
import thrift.internal.test.server;
testServeCancel!TThreadedServer();
}

View file

@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.transport.base;
import thrift.base;
import thrift.transport.base;
import thrift.util.cancellation;
/**
* Some kind of I/O device enabling servers to listen for incoming client
* connections and communicate with them via a TTransport interface.
*/
interface TServerTransport {
/**
* Starts listening for server connections.
*
* Just as simliar functions commonly found in socket libraries, this
* function does not block.
*
* If the socket is already listening, nothing happens.
*
* Throws: TServerTransportException if listening failed or the transport
* was already listening.
*/
void listen();
/**
* Closes the server transport, causing it to stop listening.
*
* Throws: TServerTransportException if the transport was not listening.
*/
void close();
/**
* Returns whether the server transport is currently listening.
*/
bool isListening() @property;
/**
* Accepts a client connection and returns an opened TTransport for it,
* never returning null.
*
* Blocks until a client connection is available.
*
* Params:
* cancellation = If triggered, requests the call to stop blocking and
* return with a TCancelledException. Implementations are free to
* ignore this if they cannot provide a reasonable.
*
* Throws: TServerTransportException if accepting failed,
* TCancelledException if it was cancelled.
*/
TTransport accept(TCancellation cancellation = null) out (result) {
assert(result !is null);
}
}
/**
* Server transport exception.
*/
class TServerTransportException : TException {
/**
* Error codes for the various types of exceptions.
*/
enum Type {
///
UNKNOWN,
/// The server socket is not listening, but excepted to be.
NOT_LISTENING,
/// The server socket is already listening, but expected not to be.
ALREADY_LISTENING,
/// An operation on the primary underlying resource, e.g. a socket used
/// for accepting connections, failed.
RESOURCE_FAILED
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
string msg = "TTransportException: ";
switch (type) {
case Type.UNKNOWN: msg ~= "Unknown server transport exception"; break;
case Type.NOT_LISTENING: msg ~= "Server transport not listening"; break;
case Type.ALREADY_LISTENING: msg ~= "Server transport already listening"; break;
case Type.RESOURCE_FAILED: msg ~= "An underlying resource failed"; break;
default: msg ~= "(Invalid exception type)"; break;
}
this(msg, type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() const nothrow @property {
return type_;
}
protected:
Type type_;
}

View file

@ -0,0 +1,380 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.transport.socket;
import core.thread : dur, Duration, Thread;
import core.stdc.string : strerror;
import std.array : empty;
import std.conv : text, to;
import std.exception : enforce;
import std.socket;
import thrift.base;
import thrift.internal.socket;
import thrift.server.transport.base;
import thrift.transport.base;
import thrift.transport.socket;
import thrift.util.awaitable;
import thrift.util.cancellation;
private alias TServerTransportException STE;
/**
* Server socket implementation of TServerTransport.
*
* Maps to std.socket listen()/accept(); only provides TCP/IP sockets (i.e. no
* Unix sockets) for now, because they are not supported in std.socket.
*/
class TServerSocket : TServerTransport {
/**
* Constructs a new instance.
*
* Params:
* port = The TCP port to listen at (host is always 0.0.0.0).
* sendTimeout = The socket sending timeout.
* recvTimout = The socket receiving timeout.
*/
this(ushort port, Duration sendTimeout = dur!"hnsecs"(0),
Duration recvTimeout = dur!"hnsecs"(0))
{
port_ = port;
sendTimeout_ = sendTimeout;
recvTimeout_ = recvTimeout;
cancellationNotifier_ = new TSocketNotifier;
socketSet_ = new SocketSet;
}
/// The port the server socket listens at.
ushort port() const @property {
return port_;
}
/// The socket sending timeout, zero to block infinitely.
void sendTimeout(Duration sendTimeout) @property {
sendTimeout_ = sendTimeout;
}
/// The socket receiving timeout, zero to block infinitely.
void recvTimeout(Duration recvTimeout) @property {
recvTimeout_ = recvTimeout;
}
/// The maximum number of listening retries if it fails.
void retryLimit(ushort retryLimit) @property {
retryLimit_ = retryLimit;
}
/// The delay between a listening attempt failing and retrying it.
void retryDelay(Duration retryDelay) @property {
retryDelay_ = retryDelay;
}
/// The size of the TCP send buffer, in bytes.
void tcpSendBuffer(int tcpSendBuffer) @property {
tcpSendBuffer_ = tcpSendBuffer;
}
/// The size of the TCP receiving buffer, in bytes.
void tcpRecvBuffer(int tcpRecvBuffer) @property {
tcpRecvBuffer_ = tcpRecvBuffer;
}
/// Whether to listen on IPv6 only, if IPv6 support is detected
/// (default: false).
void ipv6Only(bool value) @property {
ipv6Only_ = value;
}
override void listen() {
enforce(!isListening, new STE(STE.Type.ALREADY_LISTENING));
serverSocket_ = makeSocketAndListen(port_, ACCEPT_BACKLOG, retryLimit_,
retryDelay_, tcpSendBuffer_, tcpRecvBuffer_, ipv6Only_);
}
override void close() {
enforce(isListening, new STE(STE.Type.NOT_LISTENING));
serverSocket_.shutdown(SocketShutdown.BOTH);
serverSocket_.close();
serverSocket_ = null;
}
override bool isListening() @property {
return serverSocket_ !is null;
}
/// Number of connections listen() backlogs.
enum ACCEPT_BACKLOG = 1024;
override TTransport accept(TCancellation cancellation = null) {
enforce(isListening, new STE(STE.Type.NOT_LISTENING));
if (cancellation) cancellationNotifier_.attach(cancellation.triggering);
scope (exit) if (cancellation) cancellationNotifier_.detach();
// Too many EINTRs is a fault condition and would need to be handled
// manually by our caller, but we can tolerate a certain number.
enum MAX_EINTRS = 10;
uint numEintrs;
while (true) {
socketSet_.reset();
socketSet_.add(serverSocket_);
socketSet_.add(cancellationNotifier_.socket);
auto ret = Socket.select(socketSet_, null, null);
enforce(ret != 0, new STE("Socket.select() returned 0.",
STE.Type.RESOURCE_FAILED));
if (ret < 0) {
// Select itself failed, check if it was just due to an interrupted
// syscall.
if (getSocketErrno() == INTERRUPTED_ERRNO) {
if (numEintrs++ < MAX_EINTRS) {
continue;
} else {
throw new STE("Socket.select() was interrupted by a signal (EINTR) " ~
"more than " ~ to!string(MAX_EINTRS) ~ " times.",
STE.Type.RESOURCE_FAILED
);
}
}
throw new STE("Unknown error on Socket.select(): " ~
socketErrnoString(getSocketErrno()), STE.Type.RESOURCE_FAILED);
} else {
// Check for a ping on the interrupt socket.
if (socketSet_.isSet(cancellationNotifier_.socket)) {
cancellation.throwIfTriggered();
}
// Check for the actual server socket having a connection waiting.
if (socketSet_.isSet(serverSocket_)) {
break;
}
}
}
try {
auto client = createTSocket(serverSocket_.accept());
client.sendTimeout = sendTimeout_;
client.recvTimeout = recvTimeout_;
return client;
} catch (SocketException e) {
throw new STE("Unknown error on accepting: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
}
protected:
/**
* Allows derived classes to create a different TSocket type.
*/
TSocket createTSocket(Socket socket) {
return new TSocket(socket);
}
private:
ushort port_;
Duration sendTimeout_;
Duration recvTimeout_;
ushort retryLimit_;
Duration retryDelay_;
uint tcpSendBuffer_;
uint tcpRecvBuffer_;
bool ipv6Only_;
Socket serverSocket_;
TSocketNotifier cancellationNotifier_;
// Keep socket set between accept() calls to avoid reallocating.
SocketSet socketSet_;
}
Socket makeSocketAndListen(ushort port, int backlog, ushort retryLimit,
Duration retryDelay, uint tcpSendBuffer = 0, uint tcpRecvBuffer = 0,
bool ipv6Only = false
) {
Address localAddr;
try {
// null represents the wildcard address.
auto addrInfos = getAddressInfo(null, to!string(port),
AddressInfoFlags.PASSIVE, SocketType.STREAM, ProtocolType.TCP);
foreach (i, ai; addrInfos) {
// Prefer to bind to IPv6 addresses, because then IPv4 is listened to as
// well, but not the other way round.
if (ai.family == AddressFamily.INET6 || i == (addrInfos.length - 1)) {
localAddr = ai.address;
break;
}
}
} catch (Exception e) {
throw new STE("Could not determine local address to listen on.",
STE.Type.RESOURCE_FAILED, __FILE__, __LINE__, e);
}
Socket socket;
try {
socket = new Socket(localAddr.addressFamily, SocketType.STREAM,
ProtocolType.TCP);
} catch (SocketException e) {
throw new STE("Could not create accepting socket: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
try {
socket.setOption(SocketOptionLevel.IPV6, SocketOption.IPV6_V6ONLY, ipv6Only);
} catch (SocketException e) {
// This is somewhat expected on older systems (e.g. pre-Vista Windows),
// which do not support the IPV6_V6ONLY flag yet. Racy flag just to avoid
// log spew in unit tests.
shared static warned = false;
if (!warned) {
logError("Could not set IPV6_V6ONLY socket option: %s", e);
warned = true;
}
}
alias SocketOptionLevel.SOCKET lvlSock;
// Prevent 2 maximum segement lifetime delay on accept.
try {
socket.setOption(lvlSock, SocketOption.REUSEADDR, true);
} catch (SocketException e) {
throw new STE("Could not set REUSEADDR socket option: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
// Set TCP buffer sizes.
if (tcpSendBuffer > 0) {
try {
socket.setOption(lvlSock, SocketOption.SNDBUF, tcpSendBuffer);
} catch (SocketException e) {
throw new STE("Could not set socket send buffer size: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
}
if (tcpRecvBuffer > 0) {
try {
socket.setOption(lvlSock, SocketOption.RCVBUF, tcpRecvBuffer);
} catch (SocketException e) {
throw new STE("Could not set receive send buffer size: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
}
// Turn linger off to avoid blocking on socket close.
try {
Linger l;
l.on = 0;
l.time = 0;
socket.setOption(lvlSock, SocketOption.LINGER, l);
} catch (SocketException e) {
throw new STE("Could not disable socket linger: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
// Set TCP_NODELAY.
try {
socket.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
} catch (SocketException e) {
throw new STE("Could not disable Nagle's algorithm: " ~ to!string(e),
STE.Type.RESOURCE_FAILED);
}
ushort retries;
while (true) {
try {
socket.bind(localAddr);
break;
} catch (SocketException) {}
// If bind() worked, we breaked outside the loop above.
retries++;
if (retries < retryLimit) {
Thread.sleep(retryDelay);
} else {
throw new STE(text("Could not bind to address: ", localAddr),
STE.Type.RESOURCE_FAILED);
}
}
socket.listen(backlog);
return socket;
}
unittest {
// Test interrupt().
{
auto sock = new TServerSocket(0);
sock.listen();
scope (exit) sock.close();
auto cancellation = new TCancellationOrigin;
auto intThread = new Thread({
// Sleep for a bit until the socket is accepting.
Thread.sleep(dur!"msecs"(50));
cancellation.trigger();
});
intThread.start();
import std.exception;
assertThrown!TCancelledException(sock.accept(cancellation));
}
// Test receive() timeout on accepted client sockets.
{
immutable port = 11122;
auto timeout = dur!"msecs"(500);
auto serverSock = new TServerSocket(port, timeout, timeout);
serverSock.listen();
scope (exit) serverSock.close();
auto clientSock = new TSocket("127.0.0.1", port);
clientSock.open();
scope (exit) clientSock.close();
shared bool hasTimedOut;
auto recvThread = new Thread({
auto sock = serverSock.accept();
ubyte[1] data;
try {
sock.read(data);
} catch (TTransportException e) {
if (e.type == TTransportException.Type.TIMED_OUT) {
hasTimedOut = true;
} else {
import std.stdio;
stderr.writeln(e);
}
}
});
recvThread.isDaemon = true;
recvThread.start();
// Wait for the timeout, with a little bit of spare time.
Thread.sleep(timeout + dur!"msecs"(50));
enforce(hasTimedOut,
"Client socket receive() blocked for longer than recvTimeout.");
}
}

View file

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.server.transport.ssl;
import std.datetime : Duration;
import std.exception : enforce;
import std.socket : Socket;
import thrift.server.transport.socket;
import thrift.transport.base;
import thrift.transport.socket;
import thrift.transport.ssl;
/**
* A server transport implementation using SSL-encrypted sockets.
*
* Note:
* On Posix systems which do not have the BSD-specific SO_NOSIGPIPE flag, you
* might want to ignore the SIGPIPE signal, as OpenSSL might try to write to
* a closed socket if the peer disconnects abruptly:
* ---
* import core.stdc.signal;
* import core.sys.posix.signal;
* signal(SIGPIPE, SIG_IGN);
* ---
*
* See: thrift.transport.ssl.
*/
class TSSLServerSocket : TServerSocket {
/**
* Creates a new TSSLServerSocket.
*
* Params:
* port = The port on which to listen.
* sslContext = The TSSLContext to use for creating client
* sockets. Must be in server-side mode.
*/
this(ushort port, TSSLContext sslContext) {
super(port);
setSSLContext(sslContext);
}
/**
* Creates a new TSSLServerSocket.
*
* Params:
* port = The port on which to listen.
* sendTimeout = The send timeout to set on the client sockets.
* recvTimeout = The receive timeout to set on the client sockets.
* sslContext = The TSSLContext to use for creating client
* sockets. Must be in server-side mode.
*/
this(ushort port, Duration sendTimeout, Duration recvTimeout,
TSSLContext sslContext)
{
super(port, sendTimeout, recvTimeout);
setSSLContext(sslContext);
}
protected:
override TSocket createTSocket(Socket socket) {
return new TSSLSocket(sslContext_, socket);
}
private:
void setSSLContext(TSSLContext sslContext) {
enforce(sslContext.serverSide, new TTransportException(
"Need server-side SSL socket factory for TSSLServerSocket"));
sslContext_ = sslContext;
}
TSSLContext sslContext_;
}

View file

@ -0,0 +1,370 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.base;
import core.stdc.string : strerror;
import std.conv : text;
import thrift.base;
/**
* An entity data can be read from and/or written to.
*
* A TTransport implementation may capable of either reading or writing, but
* not necessarily both.
*/
interface TTransport {
/**
* Whether this transport is open.
*
* If a transport is closed, it can be opened by calling open(), and vice
* versa for close().
*
* While a transport should always be open when trying to read/write data,
* the related functions do not necessarily fail when called for a closed
* transport. Situations like this could occur e.g. with a wrapper
* transport which buffers data when the underlying transport has already
* been closed (possibly because the connection was abruptly closed), but
* there is still data left to be read in the buffers. This choice has been
* made to simplify transport implementations, in terms of both code
* complexity and runtime overhead.
*/
bool isOpen() @property;
/**
* Tests whether there is more data to read or if the remote side is
* still open.
*
* A typical use case would be a server checking if it should process
* another request on the transport.
*/
bool peek();
/**
* Opens the transport for communications.
*
* If the transport is already open, nothing happens.
*
* Throws: TTransportException if opening fails.
*/
void open();
/**
* Closes the transport.
*
* If the transport is not open, nothing happens.
*
* Throws: TTransportException if closing fails.
*/
void close();
/**
* Attempts to fill the given buffer by reading data.
*
* For potentially blocking data sources (e.g. sockets), read() will only
* block if no data is available at all. If there is some data available,
* but waiting for new data to arrive would be required to fill the whole
* buffer, the readily available data will be immediately returned use
* readAll() if you want to wait until the whole buffer is filled.
*
* Params:
* buf = Slice to use as buffer.
*
* Returns: How many bytes were actually read
*
* Throws: TTransportException if an error occurs.
*/
size_t read(ubyte[] buf);
/**
* Fills the given buffer by reading data into it, failing if not enough
* data is available.
*
* Params:
* buf = Slice to use as buffer.
*
* Throws: TTransportException if insufficient data is available or reading
* fails altogether.
*/
void readAll(ubyte[] buf);
/**
* Must be called by clients when read is completed.
*
* Implementations can choose to perform a transport-specific action, e.g.
* logging the request to a file.
*
* Returns: The number of bytes read if available, 0 otherwise.
*/
size_t readEnd();
/**
* Writes the passed slice of data.
*
* Note: You must call flush() to ensure the data is actually written,
* and available to be read back in the future. Destroying a TTransport
* object does not automatically flush pending data if you destroy a
* TTransport object with written but unflushed data, that data may be
* discarded.
*
* Params:
* buf = Slice of data to write.
*
* Throws: TTransportException if an error occurs.
*/
void write(in ubyte[] buf);
/**
* Must be called by clients when write is completed.
*
* Implementations can choose to perform a transport-specific action, e.g.
* logging the request to a file.
*
* Returns: The number of bytes written if available, 0 otherwise.
*/
size_t writeEnd();
/**
* Flushes any pending data to be written.
*
* Must be called before destruction to ensure writes are actually complete,
* otherwise pending data may be discarded. Typically used with buffered
* transport mechanisms.
*
* Throws: TTransportException if an error occurs.
*/
void flush();
/**
* Attempts to return a slice of <code>len</code> bytes of incoming data,
* possibly copied into buf, not consuming them (i.e.: a later read will
* return the same data).
*
* This method is meant to support protocols that need to read variable-
* length fields. They can attempt to borrow the maximum amount of data that
* they will need, then <code>consume()</code> what they actually use. Some
* transports will not support this method and others will fail occasionally,
* so protocols must be prepared to fall back to <code>read()</code> if
* borrow fails.
*
* The transport must be open when calling this.
*
* Params:
* buf = A buffer where the data can be stored if needed, or null to
* indicate that the caller is not supplying storage, but would like a
* slice of an internal buffer, if available.
* len = The number of bytes to borrow.
*
* Returns: If the borrow succeeds, a slice containing the borrowed data,
* null otherwise. The slice will be at least as long as requested, but
* may be longer if the returned slice points into an internal buffer
* rather than buf.
*
* Throws: TTransportException if an error occurs.
*/
const(ubyte)[] borrow(ubyte* buf, size_t len) out (result) {
// FIXME: Commented out because len gets corrupted in
// thrift.transport.memory borrow() unittest.
version(none) assert(result is null || result.length >= len,
"Buffer returned by borrow() too short.");
}
/**
* Remove len bytes from the transport. This must always follow a borrow
* of at least len bytes, and should always succeed.
*
* The transport must be open when calling this.
*
* Params:
* len = Number of bytes to consume.
*
* Throws: TTransportException if an error occurs.
*/
void consume(size_t len);
}
/**
* Provides basic fall-back implementations of the TTransport interface.
*/
class TBaseTransport : TTransport {
override bool isOpen() @property {
return false;
}
override bool peek() {
return isOpen;
}
override void open() {
throw new TTransportException("Cannot open TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override void close() {
throw new TTransportException("Cannot close TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override size_t read(ubyte[] buf) {
throw new TTransportException("Cannot read from a TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override void readAll(ubyte[] buf) {
size_t have;
while (have < buf.length) {
size_t get = read(buf[have..$]);
if (get <= 0) {
throw new TTransportException(text("Could not readAll() ", buf.length,
" bytes as no more data was available after ", have, " bytes."),
TTransportException.Type.END_OF_FILE);
}
have += get;
}
}
override size_t readEnd() {
// Do nothing by default, not needed by all implementations.
return 0;
}
override void write(in ubyte[] buf) {
throw new TTransportException("Cannot write to a TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
override size_t writeEnd() {
// Do nothing by default, not needed by all implementations.
return 0;
}
override void flush() {
// Do nothing by default, not needed by all implementations.
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
// borrow() is allowed to fail anyway, so just return null.
return null;
}
override void consume(size_t len) {
throw new TTransportException("Cannot consume from a TBaseTransport.",
TTransportException.Type.NOT_IMPLEMENTED);
}
protected:
this() {}
}
/**
* Makes a TTransport which wraps a given source transport in some way.
*
* A common use case is inside server implementations, where the raw client
* connections accepted from e.g. TServerSocket need to be wrapped into
* buffered or compressed transports.
*/
class TTransportFactory {
/**
* Default implementation does nothing, just returns the transport given.
*/
TTransport getTransport(TTransport trans) {
return trans;
}
}
/**
* Transport factory for transports which simply wrap an underlying TTransport
* without requiring additional configuration.
*/
class TWrapperTransportFactory(T) if (
is(T : TTransport) && __traits(compiles, new T(TTransport.init))
) : TTransportFactory {
override T getTransport(TTransport trans) {
return new T(trans);
}
}
/**
* Transport-level exception.
*/
class TTransportException : TException {
/**
* Error codes for the various types of exceptions.
*/
enum Type {
UNKNOWN, ///
NOT_OPEN, ///
TIMED_OUT, ///
END_OF_FILE, ///
INTERRUPTED, ///
BAD_ARGS, ///
CORRUPTED_DATA, ///
INTERNAL_ERROR, ///
NOT_IMPLEMENTED ///
}
///
this(Type type, string file = __FILE__, size_t line = __LINE__, Throwable next = null) {
static string msgForType(Type type) {
switch (type) {
case Type.UNKNOWN: return "Unknown transport exception";
case Type.NOT_OPEN: return "Transport not open";
case Type.TIMED_OUT: return "Timed out";
case Type.END_OF_FILE: return "End of file";
case Type.INTERRUPTED: return "Interrupted";
case Type.BAD_ARGS: return "Invalid arguments";
case Type.CORRUPTED_DATA: return "Corrupted Data";
case Type.INTERNAL_ERROR: return "Internal error";
case Type.NOT_IMPLEMENTED: return "Not implemented";
default: return "(Invalid exception type)";
}
}
this(msgForType(type), type, file, line, next);
}
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
this(msg, Type.UNKNOWN, file, line, next);
}
///
this(string msg, Type type, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
type_ = type;
}
///
Type type() const nothrow @property {
return type_;
}
protected:
Type type_;
}
/**
* Meta-programming helper returning whether the passed type is a TTransport
* implementation.
*/
template isTTransport(T) {
enum isTTransport = is(T : TTransport);
}

View file

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.buffered;
import std.algorithm : min;
import std.array : empty;
import std.exception : enforce;
import thrift.transport.base;
/**
* Wraps another transport and buffers reads and writes until the internal
* buffers are exhausted, at which point new data is fetched resp. the
* accumulated data is written out at once.
*/
final class TBufferedTransport : TBaseTransport {
/**
* Constructs a new instance, using the default buffer sizes.
*
* Params:
* transport = The underlying transport to wrap.
*/
this(TTransport transport) {
this(transport, DEFAULT_BUFFER_SIZE);
}
/**
* Constructs a new instance, using the specified buffer size.
*
* Params:
* transport = The underlying transport to wrap.
* bufferSize = The size of the read and write buffers to use, in bytes.
*/
this(TTransport transport, size_t bufferSize) {
this(transport, bufferSize, bufferSize);
}
/**
* Constructs a new instance, using the specified buffer size.
*
* Params:
* transport = The underlying transport to wrap.
* readBufferSize = The size of the read buffer to use, in bytes.
* writeBufferSize = The size of the write buffer to use, in bytes.
*/
this(TTransport transport, size_t readBufferSize, size_t writeBufferSize) {
transport_ = transport;
readBuffer_ = new ubyte[readBufferSize];
writeBuffer_ = new ubyte[writeBufferSize];
writeAvail_ = writeBuffer_;
}
/// The default size of the read/write buffers, in bytes.
enum int DEFAULT_BUFFER_SIZE = 512;
override bool isOpen() @property {
return transport_.isOpen();
}
override bool peek() {
if (readAvail_.empty) {
// If there is nothing available to read, see if we can get something
// from the underlying transport.
auto bytesRead = transport_.read(readBuffer_);
readAvail_ = readBuffer_[0 .. bytesRead];
}
return !readAvail_.empty;
}
override void open() {
transport_.open();
}
override void close() {
if (!isOpen) return;
flush();
transport_.close();
}
override size_t read(ubyte[] buf) {
if (readAvail_.empty) {
// No data left in our buffer, fetch some from the underlying transport.
if (buf.length > readBuffer_.length) {
// If the amount of data requested is larger than our reading buffer,
// directly read to the passed buffer. This probably doesn't occur too
// often in practice (and even if it does, the underlying transport
// probably cannot fulfill the request at once anyway), but it can't
// harm to try…
return transport_.read(buf);
}
auto bytesRead = transport_.read(readBuffer_);
readAvail_ = readBuffer_[0 .. bytesRead];
}
// Hand over whatever we have.
auto give = min(readAvail_.length, buf.length);
buf[0 .. give] = readAvail_[0 .. give];
readAvail_ = readAvail_[give .. $];
return give;
}
/**
* Shortcut version of readAll.
*/
override void readAll(ubyte[] buf) {
if (readAvail_.length >= buf.length) {
buf[] = readAvail_[0 .. buf.length];
readAvail_ = readAvail_[buf.length .. $];
return;
}
super.readAll(buf);
}
override void write(in ubyte[] buf) {
if (writeAvail_.length >= buf.length) {
// If the data fits in the buffer, just save it there.
writeAvail_[0 .. buf.length] = buf;
writeAvail_ = writeAvail_[buf.length .. $];
return;
}
// We have to decide if we copy data from buf to our internal buffer, or
// just directly write them out. The same considerations about avoiding
// syscalls as for C++ apply here.
auto bytesAvail = writeAvail_.ptr - writeBuffer_.ptr;
if ((bytesAvail + buf.length >= 2 * writeBuffer_.length) || (bytesAvail == 0)) {
// We would immediately need two syscalls anyway (or we don't have
// anything) in our buffer to write, so just write out both buffers.
if (bytesAvail > 0) {
transport_.write(writeBuffer_[0 .. bytesAvail]);
writeAvail_ = writeBuffer_;
}
transport_.write(buf);
return;
}
// Fill up our internal buffer for a write.
writeAvail_[] = buf[0 .. writeAvail_.length];
auto left = buf[writeAvail_.length .. $];
transport_.write(writeBuffer_);
// Copy the rest into our buffer.
writeBuffer_[0 .. left.length] = left[];
writeAvail_ = writeBuffer_[left.length .. $];
}
override void flush() {
// Write out any data waiting in the write buffer.
auto bytesAvail = writeAvail_.ptr - writeBuffer_.ptr;
if (bytesAvail > 0) {
// Note that we reset writeAvail_ prior to calling the underlying protocol
// to make sure the buffer is cleared even if the transport throws an
// exception.
writeAvail_ = writeBuffer_;
transport_.write(writeBuffer_[0 .. bytesAvail]);
}
// Flush the underlying transport.
transport_.flush();
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= readAvail_.length) {
return readAvail_;
}
return null;
}
override void consume(size_t len) {
enforce(len <= readBuffer_.length, new TTransportException(
"Invalid consume length.", TTransportException.Type.BAD_ARGS));
readAvail_ = readAvail_[len .. $];
}
/**
* The wrapped transport.
*/
TTransport underlyingTransport() @property {
return transport_;
}
private:
TTransport transport_;
ubyte[] readBuffer_;
ubyte[] writeBuffer_;
ubyte[] readAvail_;
ubyte[] writeAvail_;
}
/**
* Wraps given transports into TBufferedTransports.
*/
alias TWrapperTransportFactory!TBufferedTransport TBufferedTransportFactory;

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,334 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.framed;
import core.bitop : bswap;
import std.algorithm : min;
import std.array : empty;
import std.exception : enforce;
import thrift.transport.base;
/**
* Framed transport.
*
* All writes go into an in-memory buffer until flush is called, at which point
* the transport writes the length of the entire binary chunk followed by the
* data payload. The receiver on the other end then performs a single
* »fixed-length« read to get the whole message off the wire.
*/
final class TFramedTransport : TBaseTransport {
/**
* Constructs a new framed transport.
*
* Params:
* transport = The underlying transport to wrap.
*/
this(TTransport transport) {
transport_ = transport;
}
/**
* Returns the wrapped transport.
*/
TTransport underlyingTransport() @property {
return transport_;
}
override bool isOpen() @property {
return transport_.isOpen;
}
override bool peek() {
return rBuf_.length > 0 || transport_.peek();
}
override void open() {
transport_.open();
}
override void close() {
flush();
transport_.close();
}
/**
* Attempts to read data into the given buffer, stopping when the buffer is
* exhausted or the frame end is reached.
*
* TODO: Contrary to the C++ implementation, this never does cross-frame
* reads is there actually a valid use case for that?
*
* Params:
* buf = Slice to use as buffer.
*
* Returns: How many bytes were actually read.
*
* Throws: TTransportException if an error occurs.
*/
override size_t read(ubyte[] buf) {
// If the buffer is empty, read a new frame off the wire.
if (rBuf_.empty) {
bool gotFrame = readFrame();
if (!gotFrame) return 0;
}
auto size = min(rBuf_.length, buf.length);
buf[0..size] = rBuf_[0..size];
rBuf_ = rBuf_[size..$];
return size;
}
override void write(in ubyte[] buf) {
wBuf_ ~= buf;
}
override void flush() {
if (wBuf_.empty) return;
// Properly reset the write buffer even some of the protocol operations go
// wrong.
scope (exit) {
wBuf_.length = 0;
wBuf_.assumeSafeAppend();
}
int len = bswap(cast(int)wBuf_.length);
transport_.write(cast(ubyte[])(&len)[0..1]);
transport_.write(wBuf_);
transport_.flush();
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= rBuf_.length) {
return rBuf_;
} else {
// Don't try attempting cross-frame borrows, trying that does not make
// much sense anyway.
return null;
}
}
override void consume(size_t len) {
enforce(len <= rBuf_.length, new TTransportException(
"Invalid consume length", TTransportException.Type.BAD_ARGS));
rBuf_ = rBuf_[len .. $];
}
private:
bool readFrame() {
// Read the size of the next frame. We can't use readAll() since that
// always throws an exception on EOF, but want to throw an exception only
// if EOF occurs after partial size data.
int size;
size_t size_read;
while (size_read < size.sizeof) {
auto data = (cast(ubyte*)&size)[size_read..size.sizeof];
auto read = transport_.read(data);
if (read == 0) {
if (size_read == 0) {
// EOF before any data was read.
return false;
} else {
// EOF after a partial frame header illegal.
throw new TTransportException(
"No more data to read after partial frame header",
TTransportException.Type.END_OF_FILE
);
}
}
size_read += read;
}
size = bswap(size);
enforce(size >= 0, new TTransportException("Frame size has negative value",
TTransportException.Type.CORRUPTED_DATA));
// TODO: Benchmark this.
rBuf_.length = size;
rBuf_.assumeSafeAppend();
transport_.readAll(rBuf_);
return true;
}
TTransport transport_;
ubyte[] rBuf_;
ubyte[] wBuf_;
}
/**
* Wraps given transports into TFramedTransports.
*/
alias TWrapperTransportFactory!TFramedTransport TFramedTransportFactory;
version (unittest) {
import std.random : Mt19937, uniform;
import thrift.transport.memory;
}
// Some basic random testing, always starting with the same seed for
// deterministic unit test results more tests in transport_test.
unittest {
auto randGen = Mt19937(42);
// 32 kiB of data to work with.
auto data = new ubyte[1 << 15];
foreach (ref b; data) {
b = uniform!"[]"(cast(ubyte)0, cast(ubyte)255, randGen);
}
// Generate a list of chunk sizes to split the data into. A uniform
// distribution is not quite realistic, but std.random doesn't have anything
// else yet.
enum MAX_FRAME_LENGTH = 512;
auto chunkSizesList = new size_t[][2];
foreach (ref chunkSizes; chunkSizesList) {
size_t sum;
while (true) {
auto curLen = uniform(0, MAX_FRAME_LENGTH, randGen);
sum += curLen;
if (sum > data.length) break;
chunkSizes ~= curLen;
}
}
chunkSizesList ~= [data.length]; // Also test whole chunk at once.
// Test writing data.
{
foreach (chunkSizes; chunkSizesList) {
auto buf = new TMemoryBuffer;
auto framed = new TFramedTransport(buf);
auto remainingData = data;
foreach (chunkSize; chunkSizes) {
framed.write(remainingData[0..chunkSize]);
remainingData = remainingData[chunkSize..$];
}
framed.flush();
auto writtenData = data[0..($ - remainingData.length)];
auto actualData = buf.getContents();
// Check frame size.
int frameSize = bswap((cast(int[])(actualData[0..int.sizeof]))[0]);
enforce(frameSize == writtenData.length);
// Check actual data.
enforce(actualData[int.sizeof..$] == writtenData);
}
}
// Test reading data.
{
foreach (chunkSizes; chunkSizesList) {
auto buf = new TMemoryBuffer;
auto size = bswap(cast(int)data.length);
buf.write(cast(ubyte[])(&size)[0..1]);
buf.write(data);
auto framed = new TFramedTransport(buf);
ubyte[] readData;
readData.reserve(data.length);
foreach (chunkSize; chunkSizes) {
// This should work with read because we have one huge frame.
auto oldReadLen = readData.length;
readData.length += chunkSize;
framed.read(readData[oldReadLen..$]);
}
enforce(readData == data[0..readData.length]);
}
}
// Test combined reading/writing of multiple frames.
foreach (flushProbability; [1, 2, 4, 8, 16, 32]) {
foreach (chunkSizes; chunkSizesList) {
auto buf = new TMemoryBuffer;
auto framed = new TFramedTransport(buf);
size_t[] frameSizes;
// Write the data.
size_t frameSize;
auto remainingData = data;
foreach (chunkSize; chunkSizes) {
framed.write(remainingData[0..chunkSize]);
remainingData = remainingData[chunkSize..$];
frameSize += chunkSize;
if (frameSize > 0 && uniform(0, flushProbability, randGen) == 0) {
frameSizes ~= frameSize;
frameSize = 0;
framed.flush();
}
}
if (frameSize > 0) {
frameSizes ~= frameSize;
frameSize = 0;
framed.flush();
}
// Read it back.
auto readData = new ubyte[data.length - remainingData.length];
auto remainToRead = readData;
foreach (fSize; frameSizes) {
// We are exploiting an implementation detail of TFramedTransport:
// The read buffer starts empty and it will never return more than one
// frame per read, so by just requesting all of the data, we should
// always get exactly one frame.
auto got = framed.read(remainToRead);
enforce(got == fSize);
remainToRead = remainToRead[fSize..$];
}
enforce(remainToRead.empty);
enforce(readData == data[0..readData.length]);
}
}
}
// Test flush()ing an empty buffer.
unittest {
auto buf = new TMemoryBuffer();
auto framed = new TFramedTransport(buf);
immutable out1 = [0, 0, 0, 1, 'a'];
immutable out2 = [0, 0, 0, 1, 'a', 0, 0, 0, 2, 'b', 'c'];
framed.flush();
enforce(buf.getContents() == []);
framed.flush();
framed.flush();
enforce(buf.getContents() == []);
framed.write(cast(ubyte[])"a");
enforce(buf.getContents() == []);
framed.flush();
enforce(buf.getContents() == out1);
framed.flush();
framed.flush();
enforce(buf.getContents() == out1);
framed.write(cast(ubyte[])"bc");
enforce(buf.getContents() == out1);
framed.flush();
enforce(buf.getContents() == out2);
framed.flush();
framed.flush();
enforce(buf.getContents() == out2);
}

View file

@ -0,0 +1,459 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* HTTP tranpsort implementation, modelled after the C++ one.
*
* Unfortunately, libcurl is quite heavyweight and supports only client-side
* applications. This is an implementation of the basic HTTP/1.1 parts
* supporting HTTP 100 Continue, chunked transfer encoding, keepalive, etc.
*/
module thrift.transport.http;
import std.algorithm : canFind, countUntil, endsWith, findSplit, min, startsWith;
import std.ascii : toLower;
import std.array : empty;
import std.conv : parse, to;
import std.datetime : Clock, UTC;
import std.string : stripLeft;
import thrift.base : VERSION;
import thrift.transport.base;
import thrift.transport.memory;
import thrift.transport.socket;
/**
* Base class for both client- and server-side HTTP transports.
*/
abstract class THttpTransport : TBaseTransport {
this(TTransport transport) {
transport_ = transport;
readHeaders_ = true;
httpBuf_ = new ubyte[HTTP_BUFFER_SIZE];
httpBufRemaining_ = httpBuf_[0 .. 0];
readBuffer_ = new TMemoryBuffer;
writeBuffer_ = new TMemoryBuffer;
}
override bool isOpen() {
return transport_.isOpen();
}
override bool peek() {
return transport_.peek();
}
override void open() {
transport_.open();
}
override void close() {
transport_.close();
}
override size_t read(ubyte[] buf) {
if (!readBuffer_.peek()) {
readBuffer_.reset();
if (!refill()) return 0;
if (readHeaders_) {
readHeaders();
}
size_t got;
if (chunked_) {
got = readChunked();
} else {
got = readContent(contentLength_);
}
readHeaders_ = true;
if (got == 0) return 0;
}
return readBuffer_.read(buf);
}
override size_t readEnd() {
// Read any pending chunked data (footers etc.)
if (chunked_) {
while (!chunkedDone_) {
readChunked();
}
}
return 0;
}
override void write(in ubyte[] buf) {
writeBuffer_.write(buf);
}
override void flush() {
auto data = writeBuffer_.getContents();
string header = getHeader(data.length);
transport_.write(cast(const(ubyte)[]) header);
transport_.write(data);
transport_.flush();
// Reset the buffer and header variables.
writeBuffer_.reset();
readHeaders_ = true;
}
/**
* The size of the buffer to read HTTP requests into, in bytes. Will expand
* as required.
*/
enum HTTP_BUFFER_SIZE = 1024;
protected:
abstract string getHeader(size_t dataLength);
abstract bool parseStatusLine(const(ubyte)[] status);
void parseHeader(const(ubyte)[] header) {
auto split = findSplit(header, [':']);
if (split[1].empty) {
// No colon found.
return;
}
static bool compToLower(ubyte a, ubyte b) {
return toLower(cast(char)a) == toLower(cast(char)b);
}
if (startsWith!compToLower(split[0], cast(ubyte[])"transfer-encoding")) {
if (endsWith!compToLower(split[2], cast(ubyte[])"chunked")) {
chunked_ = true;
}
} else if (startsWith!compToLower(split[0], cast(ubyte[])"content-length")) {
chunked_ = false;
auto lengthString = stripLeft(cast(const(char)[])split[2]);
contentLength_ = parse!size_t(lengthString);
}
}
private:
ubyte[] readLine() {
while (true) {
auto split = findSplit(httpBufRemaining_, cast(ubyte[])"\r\n");
if (split[1].empty) {
// No CRLF yet, move whatever we have now to front and refill.
if (httpBufRemaining_.empty) {
httpBufRemaining_ = httpBuf_[0 .. 0];
} else {
httpBuf_[0 .. httpBufRemaining_.length] = httpBufRemaining_;
httpBufRemaining_ = httpBuf_[0 .. httpBufRemaining_.length];
}
if (!refill()) {
auto buf = httpBufRemaining_;
httpBufRemaining_ = httpBufRemaining_[$ - 1 .. $ - 1];
return buf;
}
} else {
// Set the remaining buffer to the part after \r\n and return the part
// (line) before it.
httpBufRemaining_ = split[2];
return split[0];
}
}
}
void readHeaders() {
// Initialize headers state variables
contentLength_ = 0;
chunked_ = false;
chunkedDone_ = false;
chunkSize_ = 0;
// Control state flow
bool statusLine = true;
bool finished;
// Loop until headers are finished
while (true) {
auto line = readLine();
if (line.length == 0) {
if (finished) {
readHeaders_ = false;
return;
} else {
// Must have been an HTTP 100, keep going for another status line
statusLine = true;
}
} else {
if (statusLine) {
statusLine = false;
finished = parseStatusLine(line);
} else {
parseHeader(line);
}
}
}
}
size_t readChunked() {
size_t length;
auto line = readLine();
size_t chunkSize;
try {
auto charLine = cast(char[])line;
chunkSize = parse!size_t(charLine, 16);
} catch (Exception e) {
throw new TTransportException("Invalid chunk size: " ~ to!string(line),
TTransportException.Type.CORRUPTED_DATA);
}
if (chunkSize == 0) {
readChunkedFooters();
} else {
// Read data content
length += readContent(chunkSize);
// Read trailing CRLF after content
readLine();
}
return length;
}
void readChunkedFooters() {
while (true) {
auto line = readLine();
if (line.length == 0) {
chunkedDone_ = true;
break;
}
}
}
size_t readContent(size_t size) {
auto need = size;
while (need > 0) {
if (httpBufRemaining_.length == 0) {
// We have given all the data, reset position to head of the buffer.
httpBufRemaining_ = httpBuf_[0 .. 0];
if (!refill()) return size - need;
}
auto give = min(httpBufRemaining_.length, need);
readBuffer_.write(cast(ubyte[])httpBufRemaining_[0 .. give]);
httpBufRemaining_ = httpBufRemaining_[give .. $];
need -= give;
}
return size;
}
bool refill() {
// Is there a nicer way to do this?
auto indexBegin = httpBufRemaining_.ptr - httpBuf_.ptr;
auto indexEnd = indexBegin + httpBufRemaining_.length;
if (httpBuf_.length - indexEnd <= (httpBuf_.length / 4)) {
httpBuf_.length *= 2;
}
// Read more data.
auto got = transport_.read(cast(ubyte[])httpBuf_[indexEnd .. $]);
if (got == 0) return false;
httpBufRemaining_ = httpBuf_[indexBegin .. indexEnd + got];
return true;
}
TTransport transport_;
TMemoryBuffer writeBuffer_;
TMemoryBuffer readBuffer_;
bool readHeaders_;
bool chunked_;
bool chunkedDone_;
size_t chunkSize_;
size_t contentLength_;
ubyte[] httpBuf_;
ubyte[] httpBufRemaining_;
}
/**
* HTTP client transport.
*/
final class TClientHttpTransport : THttpTransport {
/**
* Constructs a client http transport operating on the passed underlying
* transport.
*
* Params:
* transport = The underlying transport used for the actual I/O.
* host = The HTTP host string.
* path = The HTTP path string.
*/
this(TTransport transport, string host, string path) {
super(transport);
host_ = host;
path_ = path;
}
/**
* Convenience overload for constructing a client HTTP transport using a
* TSocket connecting to the specified host and port.
*
* Params:
* host = The server to connect to, also used as HTTP host string.
* port = The port to connect to.
* path = The HTTP path string.
*/
this(string host, ushort port, string path) {
this(new TSocket(host, port), host, path);
}
protected:
override string getHeader(size_t dataLength) {
return "POST " ~ path_ ~ " HTTP/1.1\r\n" ~
"Host: " ~ host_ ~ "\r\n" ~
"Content-Type: application/x-thrift\r\n" ~
"Content-Length: " ~ to!string(dataLength) ~ "\r\n" ~
"Accept: application/x-thrift\r\n"
"User-Agent: Thrift/" ~ VERSION ~ " (D/TClientHttpTransport)\r\n" ~
"\r\n";
}
override bool parseStatusLine(const(ubyte)[] status) {
// HTTP-Version SP Status-Code SP Reason-Phrase CRLF
auto firstSplit = findSplit(status, [' ']);
if (firstSplit[1].empty) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
auto codeReason = firstSplit[2][countUntil!"a != b"(firstSplit[2], ' ') .. $];
auto secondSplit = findSplit(codeReason, [' ']);
if (secondSplit[1].empty) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
if (secondSplit[0] == "200") {
// HTTP 200 = OK, we got the response
return true;
} else if (secondSplit[0] == "100") {
// HTTP 100 = continue, just keep reading
return false;
}
throw new TTransportException("Bad status (unhandled status code): " ~
to!string(cast(const(char[]))status), TTransportException.Type.CORRUPTED_DATA);
}
private:
string host_;
string path_;
}
/**
* HTTP server transport.
*/
final class TServerHttpTransport : THttpTransport {
/**
* Constructs a new instance.
*
* Param:
* transport = The underlying transport used for the actual I/O.
*/
this(TTransport transport) {
super(transport);
}
protected:
override string getHeader(size_t dataLength) {
return "HTTP/1.1 200 OK\r\n" ~
"Date: " ~ getRFC1123Time() ~ "\r\n" ~
"Server: Thrift/" ~ VERSION ~ "\r\n" ~
"Content-Type: application/x-thrift\r\n" ~
"Content-Length: " ~ to!string(dataLength) ~ "\r\n" ~
"Connection: Keep-Alive\r\n" ~
"\r\n";
}
override bool parseStatusLine(const(ubyte)[] status) {
// Method SP Request-URI SP HTTP-Version CRLF.
auto split = findSplit(status, [' ']);
if (split[1].empty) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
auto uriVersion = split[2][countUntil!"a != b"(split[2], ' ') .. $];
if (!canFind(uriVersion, ' ')) {
throw new TTransportException("Bad status: " ~ to!string(status),
TTransportException.Type.CORRUPTED_DATA);
}
if (split[0] == "POST") {
// POST method ok, looking for content.
return true;
}
throw new TTransportException("Bad status (unsupported method): " ~
to!string(status), TTransportException.Type.CORRUPTED_DATA);
}
}
/**
* Wraps a transport into a HTTP server protocol.
*/
alias TWrapperTransportFactory!TServerHttpTransport TServerHttpTransportFactory;
private {
import std.string : format;
string getRFC1123Time() {
auto sysTime = Clock.currTime(UTC());
auto dayName = capMemberName(sysTime.dayOfWeek);
auto monthName = capMemberName(sysTime.month);
return format("%s, %s %s %s %s:%s:%s GMT", dayName, sysTime.day,
monthName, sysTime.year, sysTime.hour, sysTime.minute, sysTime.second);
}
import std.ascii : toUpper;
import std.traits : EnumMembers;
string capMemberName(T)(T val) if (is(T == enum)) {
foreach (i, e; EnumMembers!T) {
enum name = __traits(derivedMembers, T)[i];
enum capName = cast(char) toUpper(name[0]) ~ name [1 .. $];
if (val == e) {
return capName;
}
}
throw new Exception("Not a member of " ~ T.stringof ~ ": " ~ to!string(val));
}
unittest {
enum Foo {
bar,
bAZ
}
import std.exception;
enforce(capMemberName(Foo.bar) == "Bar");
enforce(capMemberName(Foo.bAZ) == "BAZ");
}
}

View file

@ -0,0 +1,233 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.memory;
import core.exception : onOutOfMemoryError;
import core.stdc.stdlib : free, realloc;
import std.algorithm : min;
import std.conv : text;
import thrift.transport.base;
/**
* A transport that simply reads from and writes to an in-memory buffer. Every
* time you call write on it, the data is simply placed into a buffer, and
* every time you call read, data is consumed from that buffer.
*
* Currently, the storage for written data is never reclaimed, even if the
* buffer contents have already been read out again.
*/
final class TMemoryBuffer : TBaseTransport {
/**
* Constructs a new memory transport with an empty internal buffer.
*/
this() {}
/**
* Constructs a new memory transport with an empty internal buffer,
* reserving space for capacity bytes in advance.
*
* If the amount of data which will be written to the buffer is already
* known on construction, this can better performance over the default
* constructor because reallocations can be avoided.
*
* If the preallocated buffer is exhausted, data can still be written to the
* transport, but reallocations will happen.
*
* Params:
* capacity = Size of the initially reserved buffer (in bytes).
*/
this(size_t capacity) {
reset(capacity);
}
/**
* Constructs a new memory transport initially containing the passed data.
*
* For now, the passed buffer is not intelligently used, the data is just
* copied to the internal buffer.
*
* Params:
* buffer = Initial contents available to be read.
*/
this(in ubyte[] contents) {
auto size = contents.length;
reset(size);
buffer_[0 .. size] = contents[];
writeOffset_ = size;
}
/**
* Destructor, frees the internally allocated buffer.
*/
~this() {
free(buffer_);
}
/**
* Returns a read-only view of the current buffer contents.
*
* Note: For performance reasons, the returned slice is only valid for the
* life of this object, and may be invalidated on the next write() call at
* will you might want to immediately .dup it if you intend to keep it
* around.
*/
const(ubyte)[] getContents() {
return buffer_[readOffset_ .. writeOffset_];
}
/**
* A memory transport is always open.
*/
override bool isOpen() @property {
return true;
}
override bool peek() {
return writeOffset_ - readOffset_ > 0;
}
/**
* Opening is a no-op() for a memory buffer.
*/
override void open() {}
/**
* Closing is a no-op() for a memory buffer, it is always open.
*/
override void close() {}
override size_t read(ubyte[] buf) {
auto size = min(buf.length, writeOffset_ - readOffset_);
buf[0 .. size] = buffer_[readOffset_ .. readOffset_ + size];
readOffset_ += size;
return size;
}
/**
* Shortcut version of readAll() using this over TBaseTransport.readAll()
* can give us a nice speed increase because gives us a nice speed increase
* because it is typically a very hot path during deserialization.
*/
override void readAll(ubyte[] buf) {
auto available = writeOffset_ - readOffset_;
if (buf.length > available) {
throw new TTransportException(text("Cannot readAll() ", buf.length,
" bytes of data because only ", available, " bytes are available."),
TTransportException.Type.END_OF_FILE);
}
buf[] = buffer_[readOffset_ .. readOffset_ + buf.length];
readOffset_ += buf.length;
}
override void write(in ubyte[] buf) {
auto need = buf.length;
if (bufferLen_ - writeOffset_ < need) {
// Exponential growth.
auto newLen = bufferLen_ + 1;
while (newLen - writeOffset_ < need) newLen *= 2;
cRealloc(buffer_, newLen);
bufferLen_ = newLen;
}
buffer_[writeOffset_ .. writeOffset_ + need] = buf[];
writeOffset_ += need;
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= writeOffset_ - readOffset_) {
return buffer_[readOffset_ .. writeOffset_];
} else {
return null;
}
}
override void consume(size_t len) {
readOffset_ += len;
}
void reset() {
readOffset_ = 0;
writeOffset_ = 0;
}
void reset(size_t capacity) {
readOffset_ = 0;
writeOffset_ = 0;
if (bufferLen_ < capacity) {
cRealloc(buffer_, capacity);
bufferLen_ = capacity;
}
}
private:
ubyte* buffer_;
size_t bufferLen_;
size_t readOffset_;
size_t writeOffset_;
}
private {
void cRealloc(ref ubyte* data, size_t newSize) {
auto result = realloc(data, newSize);
if (result is null) onOutOfMemoryError();
data = cast(ubyte*)result;
}
}
version (unittest) {
import std.exception;
}
unittest {
auto a = new TMemoryBuffer(5);
immutable(ubyte[]) testData = [1, 2, 3, 4];
auto buf = new ubyte[testData.length];
enforce(a.isOpen);
// a should be empty.
enforce(!a.peek());
enforce(a.read(buf) == 0);
assertThrown!TTransportException(a.readAll(buf));
// Write some data and read it back again.
a.write(testData);
enforce(a.peek());
enforce(a.getContents() == testData);
enforce(a.read(buf) == testData.length);
enforce(buf == testData);
// a should be empty again.
enforce(!a.peek());
enforce(a.read(buf) == 0);
assertThrown!TTransportException(a.readAll(buf));
// Test the constructor which directly accepts initial data.
auto b = new TMemoryBuffer(testData);
enforce(b.isOpen);
enforce(b.peek());
enforce(b.getContents() == testData);
// Test borrow().
auto borrowed = b.borrow(null, testData.length);
enforce(borrowed == testData);
enforce(b.peek());
b.consume(testData.length);
enforce(!b.peek());
}

View file

@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.piped;
import thrift.transport.base;
import thrift.transport.memory;
/**
* Pipes data request from one transport to another when readEnd()
* or writeEnd() is called.
*
* A typical use case would be to log requests on e.g. a socket to
* disk (i. e. pipe them to a TFileWriterTransport).
*
* The implementation keeps an internal buffer which expands to
* hold the whole amount of data read/written until the corresponding *End()
* method is called.
*
* Contrary to the C++ implementation, this doesn't introduce yet another layer
* of input/output buffering, all calls are passed to the underlying source
* transport verbatim.
*/
final class TPipedTransport(Source = TTransport) if (
isTTransport!Source
) : TBaseTransport {
/// The default initial buffer size if not explicitly specified, in bytes.
enum DEFAULT_INITIAL_BUFFER_SIZE = 512;
/**
* Constructs a new instance.
*
* By default, only reads are piped (pipeReads = true, pipeWrites = false).
*
* Params:
* srcTrans = The transport to which all requests are forwarded.
* dstTrans = The transport the read/written data is copied to.
* initialBufferSize = The default size of the read/write buffers, for
* performance tuning.
*/
this(Source srcTrans, TTransport dstTrans,
size_t initialBufferSize = DEFAULT_INITIAL_BUFFER_SIZE
) {
srcTrans_ = srcTrans;
dstTrans_ = dstTrans;
readBuffer_ = new TMemoryBuffer(initialBufferSize);
writeBuffer_ = new TMemoryBuffer(initialBufferSize);
pipeReads_ = true;
pipeWrites_ = false;
}
bool pipeReads() @property const {
return pipeReads_;
}
void pipeReads(bool value) @property {
if (!value) {
readBuffer_.reset();
}
pipeReads_ = value;
}
bool pipeWrites() @property const {
return pipeWrites_;
}
void pipeWrites(bool value) @property {
if (!value) {
writeBuffer_.reset();
}
pipeWrites_ = value;
}
override bool isOpen() {
return srcTrans_.isOpen();
}
override bool peek() {
return srcTrans_.peek();
}
override void open() {
srcTrans_.open();
}
override void close() {
srcTrans_.close();
}
override size_t read(ubyte[] buf) {
auto bytesRead = srcTrans_.read(buf);
if (pipeReads_) {
readBuffer_.write(buf[0 .. bytesRead]);
}
return bytesRead;
}
override size_t readEnd() {
if (pipeReads_) {
auto data = readBuffer_.getContents();
dstTrans_.write(data);
dstTrans_.flush();
readBuffer_.reset();
srcTrans_.readEnd();
// Return data.length instead of the readEnd() result of the source
// transports because it might not be available from it.
return data.length;
}
return srcTrans_.readEnd();
}
override void write(in ubyte[] buf) {
if (pipeWrites_) {
writeBuffer_.write(buf);
}
srcTrans_.write(buf);
}
override size_t writeEnd() {
if (pipeWrites_) {
auto data = writeBuffer_.getContents();
dstTrans_.write(data);
dstTrans_.flush();
writeBuffer_.reset();
srcTrans_.writeEnd();
// Return data.length instead of the readEnd() result of the source
// transports because it might not be available from it.
return data.length;
}
return srcTrans_.writeEnd();
}
override void flush() {
srcTrans_.flush();
}
private:
Source srcTrans_;
TTransport dstTrans_;
TMemoryBuffer readBuffer_;
TMemoryBuffer writeBuffer_;
bool pipeReads_;
bool pipeWrites_;
}
/**
* TPipedTransport construction helper to avoid having to explicitly
* specify the transport types, i.e. to allow the constructor being called
* using IFTI (see $(DMDBUG 6082, D Bugzilla enhancement request 6082)).
*/
TPipedTransport!Source tPipedTransport(Source)(
Source srcTrans, TTransport dstTrans
) if (isTTransport!Source) {
return new typeof(return)(srcTrans, dstTrans);
}
version (unittest) {
// DMD @@BUG@@: UFCS for std.array.empty doesn't work when import is moved
// into unittest block.
import std.array;
import std.exception : enforce;
}
unittest {
auto underlying = new TMemoryBuffer;
auto pipeTarget = new TMemoryBuffer;
auto trans = tPipedTransport(underlying, pipeTarget);
underlying.write(cast(ubyte[])"abcd");
ubyte[4] buffer;
trans.readAll(buffer[0 .. 2]);
enforce(buffer[0 .. 2] == "ab");
enforce(pipeTarget.getContents().empty);
trans.readEnd();
enforce(pipeTarget.getContents() == "ab");
pipeTarget.reset();
underlying.write(cast(ubyte[])"ef");
trans.readAll(buffer[0 .. 2]);
enforce(buffer[0 .. 2] == "cd");
enforce(pipeTarget.getContents().empty);
trans.readAll(buffer[0 .. 2]);
enforce(buffer[0 .. 2] == "ef");
enforce(pipeTarget.getContents().empty);
trans.readEnd();
enforce(pipeTarget.getContents() == "cdef");
}

View file

@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* Transports which operate on generic D ranges.
*/
module thrift.transport.range;
import std.array : empty;
import std.range;
import std.traits : Unqual;
import thrift.transport.base;
/**
* Adapts an ubyte input range for reading via the TTransport interface.
*
* The case where R is a plain ubyte[] is reasonably optimized, so a possible
* use case for TInputRangeTransport would be to deserialize some data held in
* a memory buffer.
*/
final class TInputRangeTransport(R) if (
isInputRange!(Unqual!R) && is(ElementType!R : const(ubyte))
) : TBaseTransport {
/**
* Constructs a new instance.
*
* Params:
* data = The input range to use as data.
*/
this(R data) {
data_ = data;
}
/**
* An input range transport is always open.
*/
override bool isOpen() @property {
return true;
}
override bool peek() {
return !data_.empty;
}
/**
* Opening is a no-op() for an input range transport.
*/
override void open() {}
/**
* Closing is a no-op() for a memory buffer.
*/
override void close() {}
override size_t read(ubyte[] buf) {
auto data = data_.take(buf.length);
auto bytes = data.length;
static if (is(typeof(R.init[1 .. 2]) : const(ubyte)[])) {
// put() is currently unnecessarily slow if both ranges are sliceable.
buf[0 .. bytes] = data[];
data_ = data_[bytes .. $];
} else {
buf.put(data);
}
return bytes;
}
/**
* Shortcut version of readAll() for slicable ranges.
*
* Because readAll() is typically a very hot path during deserialization,
* using this over TBaseTransport.readAll() gives us a nice increase in
* speed due to the reduced amount of indirections.
*/
override void readAll(ubyte[] buf) {
static if (is(typeof(R.init[1 .. 2]) : const(ubyte)[])) {
if (buf.length <= data_.length) {
buf[] = data_[0 .. buf.length];
data_ = data_[buf.length .. $];
return;
}
}
super.readAll(buf);
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
static if (is(R : const(ubyte)[])) {
// Can only borrow if our data type is actually an ubyte array.
if (len <= data_.length) {
return data_;
}
}
return null;
}
override void consume(size_t len) {
static if (is(R : const(ubyte)[])) {
if (len > data_.length) {
throw new TTransportException("Invalid consume length",
TTransportException.Type.BAD_ARGS);
}
data_ = data_[len .. $];
} else {
super.consume(len);
}
}
/**
* Sets a new data range to use.
*/
void reset(R data) {
data_ = data;
}
private:
R data_;
}
/**
* TInputRangeTransport construction helper to avoid having to explicitly
* specify the argument type, i.e. to allow the constructor being called using
* IFTI (see $(LINK2 http://d.puremagic.com/issues/show_bug.cgi?id=6082, D
* Bugzilla enhancement requet 6082)).
*/
TInputRangeTransport!R tInputRangeTransport(R)(R data) if (
is (TInputRangeTransport!R)
) {
return new TInputRangeTransport!R(data);
}

View file

@ -0,0 +1,453 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.socket;
import core.thread : Thread;
import core.time : dur, Duration;
import std.array : empty;
import std.conv : text, to;
import std.exception : enforce;
import std.socket;
import thrift.base;
import thrift.transport.base;
import thrift.internal.socket;
/**
* Common parts of a socket TTransport implementation, regardless of how the
* actual I/O is performed (sync/async).
*/
abstract class TSocketBase : TBaseTransport {
/**
* Constructor that takes an already created, connected (!) socket.
*
* Params:
* socket = Already created, connected socket object.
*/
this(Socket socket) {
socket_ = socket;
setSocketOpts();
}
/**
* Creates a new unconnected socket that will connect to the given host
* on the given port.
*
* Params:
* host = Remote host.
* port = Remote port.
*/
this(string host, ushort port) {
host_ = host;
port_ = port;
}
/**
* Checks whether the socket is connected.
*/
override bool isOpen() @property {
return socket_ !is null;
}
/**
* Writes as much data to the socket as there can be in a single OS call.
*
* Params:
* buf = Data to write.
*
* Returns: The actual number of bytes written. Never more than buf.length.
*/
abstract size_t writeSome(in ubyte[] buf) out (written) {
// DMD @@BUG@@: Enabling this e.g. fails the contract in the
// async_test_server, because buf.length evaluates to 0 here, even though
// in the method body it correctly is 27 (equal to the return value).
version (none) assert(written <= buf.length, text("Implementation wrote " ~
"more data than requested to?! (", written, " vs. ", buf.length, ")"));
} body {
assert(0, "DMD bug? Why would contracts work for interfaces, but not "
"for abstract methods? "
"(Error: function […] in and out contracts require function body");
}
/**
* Returns the actual address of the peer the socket is connected to.
*
* In contrast, the host and port properties contain the address used to
* establish the connection, and are not updated after the connection.
*
* The socket must be open when calling this.
*/
Address getPeerAddress() {
enforce(isOpen, new TTransportException("Cannot get peer host for " ~
"closed socket.", TTransportException.Type.NOT_OPEN));
if (!peerAddress_) {
peerAddress_ = socket_.remoteAddress();
assert(peerAddress_);
}
return peerAddress_;
}
/**
* The host the socket is connected to or will connect to. Null if an
* already connected socket was used to construct the object.
*/
string host() const @property {
return host_;
}
/**
* The port the socket is connected to or will connect to. Zero if an
* already connected socket was used to construct the object.
*/
ushort port() const @property {
return port_;
}
/// The socket send timeout.
Duration sendTimeout() const @property {
return sendTimeout_;
}
/// Ditto
void sendTimeout(Duration value) @property {
sendTimeout_ = value;
}
/// The socket receiving timeout. Values smaller than 500 ms are not
/// supported on Windows.
Duration recvTimeout() const @property {
return recvTimeout_;
}
/// Ditto
void recvTimeout(Duration value) @property {
recvTimeout_ = value;
}
/**
* Returns the OS handle of the underlying socket.
*
* Should not usually be used directly, but access to it can be necessary
* to interface with C libraries.
*/
typeof(socket_.handle()) socketHandle() @property {
return socket_.handle();
}
protected:
/**
* Sets the needed socket options.
*/
void setSocketOpts() {
try {
alias SocketOptionLevel.SOCKET lvlSock;
Linger l;
l.on = 0;
l.time = 0;
socket_.setOption(lvlSock, SocketOption.LINGER, l);
} catch (SocketException e) {
logError("Could not set socket option: %s", e);
}
// Just try to disable Nagle's algorithm this will fail if we are passed
// in a non-TCP socket via the Socket-accepting constructor.
try {
socket_.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
} catch (SocketException e) {}
}
/// Remote host.
string host_;
/// Remote port.
ushort port_;
/// Timeout for sending.
Duration sendTimeout_;
/// Timeout for receiving.
Duration recvTimeout_;
/// Cached peer address.
Address peerAddress_;
/// Cached peer host name.
string peerHost_;
/// Cached peer port.
ushort peerPort_;
/// Wrapped socket object.
Socket socket_;
}
/**
* Socket implementation of the TTransport interface.
*
* Due to the limitations of std.socket, currently only TCP/IP sockets are
* supported (i.e. Unix domain sockets are not).
*/
class TSocket : TSocketBase {
///
this(Socket socket) {
super(socket);
}
///
this(string host, ushort port) {
super(host, port);
}
/**
* Connects the socket.
*/
override void open() {
if (isOpen) return;
enforce(!host_.empty, new TTransportException(
"Cannot open socket to null host.", TTransportException.Type.NOT_OPEN));
enforce(port_ != 0, new TTransportException(
"Cannot open socket to port zero.", TTransportException.Type.NOT_OPEN));
Address[] addrs;
try {
addrs = getAddress(host_, port_);
} catch (SocketException e) {
throw new TTransportException("Could not resolve given host string.",
TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
}
Exception[] errors;
foreach (addr; addrs) {
try {
socket_ = new TcpSocket(addr.addressFamily);
setSocketOpts();
socket_.connect(addr);
break;
} catch (SocketException e) {
errors ~= e;
}
}
if (errors.length == addrs.length) {
socket_ = null;
// Need to throw a TTransportException to abide the TTransport API.
import std.algorithm, std.range;
throw new TTransportException(
text("Failed to connect to ", host_, ":", port_, "."),
TTransportException.Type.NOT_OPEN,
__FILE__, __LINE__,
new TCompoundOperationException(
text(
"All addresses tried failed (",
joiner(map!q{text(a._0, `: "`, a._1.msg, `"`)}(zip(addrs, errors)), ", "),
")."
),
errors
)
);
}
}
/**
* Closes the socket.
*/
override void close() {
if (!isOpen) return;
socket_.close();
socket_ = null;
}
override bool peek() {
if (!isOpen) return false;
ubyte buf;
auto r = socket_.receive((&buf)[0 .. 1], SocketFlags.PEEK);
if (r == -1) {
auto lastErrno = getSocketErrno();
static if (connresetOnPeerShutdown) {
if (lastErrno == ECONNRESET) {
close();
return false;
}
}
throw new TTransportException("Peeking into socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
return (r > 0);
}
override size_t read(ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));
typeof(getSocketErrno()) lastErrno;
ushort tries;
while (tries++ <= maxRecvRetries_) {
auto r = socket_.receive(cast(void[])buf);
// If recv went fine, immediately return.
if (r >= 0) return r;
// Something went wrong, find out how to handle it.
lastErrno = getSocketErrno();
if (lastErrno == INTERRUPTED_ERRNO) {
// If the syscall was interrupted, just try again.
continue;
}
static if (connresetOnPeerShutdown) {
// See top comment.
if (lastErrno == ECONNRESET) {
return 0;
}
}
// Not an error which is handled in a special way, just leave the loop.
break;
}
if (isSocketCloseErrno(lastErrno)) {
close();
throw new TTransportException("Receiving failed, closing socket: " ~
socketErrnoString(lastErrno), TTransportException.Type.NOT_OPEN);
} else if (lastErrno == TIMEOUT_ERRNO) {
throw new TTransportException(TTransportException.Type.TIMED_OUT);
} else {
throw new TTransportException("Receiving from socket failed: " ~
socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
}
}
override void write(in ubyte[] buf) {
size_t sent;
while (sent < buf.length) {
auto b = writeSome(buf[sent .. $]);
if (b == 0) {
// This should only happen if the timeout set with SO_SNDTIMEO expired.
throw new TTransportException("send() timeout expired.",
TTransportException.Type.TIMED_OUT);
}
sent += b;
}
assert(sent == buf.length);
}
override size_t writeSome(in ubyte[] buf) {
enforce(isOpen, new TTransportException(
"Cannot write if file is not open.", TTransportException.Type.NOT_OPEN));
auto r = socket_.send(buf);
// Everything went well, just return the number of bytes written.
if (r > 0) return r;
// Handle error conditions.
if (r < 0) {
auto lastErrno = getSocketErrno();
if (lastErrno == WOULD_BLOCK_ERRNO) {
// Not an exceptional error per se even with blocking sockets,
// EAGAIN apparently is returned sometimes on out-of-resource
// conditions (see the C++ implementation for details). Also, this
// allows using TSocket with non-blocking sockets e.g. in
// TNonblockingServer.
return 0;
}
auto type = TTransportException.Type.UNKNOWN;
if (isSocketCloseErrno(lastErrno)) {
type = TTransportException.Type.NOT_OPEN;
close();
}
throw new TTransportException("Sending to socket failed: " ~
socketErrnoString(lastErrno), type);
}
// send() should never return 0.
throw new TTransportException("Sending to socket failed (0 bytes written).",
TTransportException.Type.UNKNOWN);
}
override void sendTimeout(Duration value) @property {
super.sendTimeout(value);
setTimeout(SocketOption.SNDTIMEO, value);
}
override void recvTimeout(Duration value) @property {
super.recvTimeout(value);
setTimeout(SocketOption.RCVTIMEO, value);
}
/**
* Maximum number of retries for receiving from socket on read() in case of
* EAGAIN/EINTR.
*/
ushort maxRecvRetries() @property const {
return maxRecvRetries_;
}
/// Ditto
void maxRecvRetries(ushort value) @property {
maxRecvRetries_ = value;
}
/// Ditto
enum DEFAULT_MAX_RECV_RETRIES = 5;
protected:
override void setSocketOpts() {
super.setSocketOpts();
setTimeout(SocketOption.SNDTIMEO, sendTimeout_);
setTimeout(SocketOption.RCVTIMEO, recvTimeout_);
}
void setTimeout(SocketOption type, Duration value) {
assert(type == SocketOption.SNDTIMEO || type == SocketOption.RCVTIMEO);
version (Win32) {
if (value > dur!"hnsecs"(0) && value < dur!"msecs"(500)) {
logError(
"Socket %s timeout of %s ms might be raised to 500 ms on Windows.",
(type == SocketOption.SNDTIMEO) ? "send" : "receive",
value.total!"msecs"
);
}
}
if (socket_) {
try {
socket_.setOption(SocketOptionLevel.SOCKET, type, value);
} catch (SocketException e) {
throw new TTransportException(
"Could not set timeout.",
TTransportException.Type.UNKNOWN,
__FILE__,
__LINE__,
e
);
}
}
}
/// Maximum number of recv() retries.
ushort maxRecvRetries_ = DEFAULT_MAX_RECV_RETRIES;
}

View file

@ -0,0 +1,680 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/**
* OpenSSL socket implementation, in large parts ported from C++.
*/
module thrift.transport.ssl;
import core.exception : onOutOfMemoryError;
import core.stdc.errno : errno, EINTR;
import core.sync.mutex : Mutex;
import core.memory : GC;
import core.stdc.config;
import core.stdc.stdlib : free, malloc;
import std.ascii : toUpper;
import std.array : empty, front, popFront;
import std.conv : emplace, to;
import std.exception : enforce;
import std.socket : Address, InternetAddress, Internet6Address, Socket;
import std.string : toStringz;
import deimos.openssl.err;
import deimos.openssl.rand;
import deimos.openssl.ssl;
import deimos.openssl.x509v3;
import thrift.base;
import thrift.internal.ssl;
import thrift.transport.base;
import thrift.transport.socket;
/**
* SSL encrypted socket implementation using OpenSSL.
*
* Note:
* On Posix systems which do not have the BSD-specific SO_NOSIGPIPE flag, you
* might want to ignore the SIGPIPE signal, as OpenSSL might try to write to
* a closed socket if the peer disconnects abruptly:
* ---
* import core.stdc.signal;
* import core.sys.posix.signal;
* signal(SIGPIPE, SIG_IGN);
* ---
*/
final class TSSLSocket : TSocket {
/**
* Creates an instance that wraps an already created, connected (!) socket.
*
* Params:
* context = The SSL socket context to use. A reference to it is stored so
* that it doesn't get cleaned up while the socket is used.
* socket = Already created, connected socket object.
*/
this(TSSLContext context, Socket socket) {
super(socket);
context_ = context;
serverSide_ = context.serverSide;
accessManager_ = context.accessManager;
}
/**
* Creates a new unconnected socket that will connect to the given host
* on the given port.
*
* Params:
* context = The SSL socket context to use. A reference to it is stored so
* that it doesn't get cleaned up while the socket is used.
* host = Remote host.
* port = Remote port.
*/
this(TSSLContext context, string host, ushort port) {
super(host, port);
context_ = context;
serverSide_ = context.serverSide;
accessManager_ = context.accessManager;
}
override bool isOpen() @property {
if (ssl_ is null || !super.isOpen()) return false;
auto shutdown = SSL_get_shutdown(ssl_);
bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN) != 0;
bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN) != 0;
return !(shutdownReceived && shutdownSent);
}
override bool peek() {
if (!isOpen) return false;
checkHandshake();
byte bt;
auto rc = SSL_peek(ssl_, &bt, bt.sizeof);
enforce(rc >= 0, getSSLException("SSL_peek"));
if (rc == 0) {
ERR_clear_error();
}
return (rc > 0);
}
override void open() {
enforce(!serverSide_, "Cannot open a server-side SSL socket.");
if (isOpen) return;
super.open();
}
override void close() {
if (!isOpen) return;
if (ssl_ !is null) {
// Two-step SSL shutdown.
auto rc = SSL_shutdown(ssl_);
if (rc == 0) {
rc = SSL_shutdown(ssl_);
}
if (rc < 0) {
// Do not throw an exception here as leaving the transport "open" will
// probably produce only more errors, and the chance we can do
// something about the error e.g. by retrying is very low.
logError("Error shutting down SSL: %s", getSSLException());
}
SSL_free(ssl_);
ssl_ = null;
ERR_remove_state(0);
}
super.close();
}
override size_t read(ubyte[] buf) {
checkHandshake();
int bytes;
foreach (_; 0 .. maxRecvRetries) {
bytes = SSL_read(ssl_, buf.ptr, cast(int)buf.length);
if (bytes >= 0) break;
auto errnoCopy = errno;
if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) {
if (ERR_get_error() == 0 && errnoCopy == EINTR) {
// FIXME: Windows.
continue;
}
}
throw getSSLException("SSL_read");
}
return bytes;
}
override void write(in ubyte[] buf) {
checkHandshake();
// Loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
size_t written = 0;
while (written < buf.length) {
auto bytes = SSL_write(ssl_, buf.ptr + written,
cast(int)(buf.length - written));
if (bytes <= 0) {
throw getSSLException("SSL_write");
}
written += bytes;
}
}
override void flush() {
checkHandshake();
auto bio = SSL_get_wbio(ssl_);
enforce(bio !is null, new TSSLException("SSL_get_wbio returned null"));
auto rc = BIO_flush(bio);
enforce(rc == 1, getSSLException("BIO_flush"));
}
/**
* Whether to use client or server side SSL handshake protocol.
*/
bool serverSide() @property const {
return serverSide_;
}
/// Ditto
void serverSide(bool value) @property {
serverSide_ = value;
}
/**
* The access manager to use.
*/
void accessManager(TAccessManager value) @property {
accessManager_ = value;
}
private:
void checkHandshake() {
enforce(super.isOpen(), new TTransportException(
TTransportException.Type.NOT_OPEN));
if (ssl_ !is null) return;
ssl_ = context_.createSSL();
SSL_set_fd(ssl_, socketHandle);
int rc;
if (serverSide_) {
rc = SSL_accept(ssl_);
} else {
rc = SSL_connect(ssl_);
}
enforce(rc > 0, getSSLException());
authorize(ssl_, accessManager_, getPeerAddress(),
(serverSide_ ? getPeerAddress().toHostNameString() : host));
}
bool serverSide_;
SSL* ssl_;
TSSLContext context_;
TAccessManager accessManager_;
}
/**
* Represents an OpenSSL context with certification settings, etc. and handles
* initialization/teardown.
*
* OpenSSL is initialized when the first instance of this class is created
* and shut down when the last one is destroyed (thread-safe).
*/
class TSSLContext {
this() {
initMutex_.lock();
scope(exit) initMutex_.unlock();
if (count_ == 0) {
initializeOpenSSL();
randomize();
}
count_++;
ctx_ = SSL_CTX_new(TLSv1_method());
enforce(ctx_, getSSLException("SSL_CTX_new"));
SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
}
~this() {
initMutex_.lock();
scope(exit) initMutex_.unlock();
if (ctx_ !is null) {
SSL_CTX_free(ctx_);
ctx_ = null;
}
count_--;
if (count_ == 0) {
cleanupOpenSSL();
}
}
/**
* Ciphers to be used in SSL handshake process.
*
* The string must be in the colon-delimited OpenSSL notation described in
* ciphers(1), for example: "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH".
*/
void ciphers(string enable) @property {
auto rc = SSL_CTX_set_cipher_list(ctx_, toStringz(enable));
enforce(ERR_peek_error() == 0, getSSLException("SSL_CTX_set_cipher_list"));
enforce(rc > 0, new TSSLException("None of specified ciphers are supported"));
}
/**
* Whether peer is required to present a valid certificate.
*/
void authenticate(bool required) @property {
int mode;
if (required) {
mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
SSL_VERIFY_CLIENT_ONCE;
} else {
mode = SSL_VERIFY_NONE;
}
SSL_CTX_set_verify(ctx_, mode, null);
}
/**
* Load server certificate.
*
* Params:
* path = Path to the certificate file.
* format = Certificate file format. Defaults to PEM, which is currently
* the only one supported.
*/
void loadCertificate(string path, string format = "PEM") {
enforce(path !is null && format !is null, new TTransportException(
"loadCertificateChain: either <path> or <format> is null",
TTransportException.Type.BAD_ARGS));
if (format == "PEM") {
enforce(SSL_CTX_use_certificate_chain_file(ctx_, toStringz(path)),
getSSLException(
`Could not load SSL server certificate from file "` ~ path ~ `"`
)
);
} else {
throw new TSSLException("Unsupported certificate format: " ~ format);
}
}
/*
* Load private key.
*
* Params:
* path = Path to the certificate file.
* format = Private key file format. Defaults to PEM, which is currently
* the only one supported.
*/
void loadPrivateKey(string path, string format = "PEM") {
enforce(path !is null && format !is null, new TTransportException(
"loadPrivateKey: either <path> or <format> is NULL",
TTransportException.Type.BAD_ARGS));
if (format == "PEM") {
enforce(SSL_CTX_use_PrivateKey_file(ctx_, toStringz(path), SSL_FILETYPE_PEM),
getSSLException(
`Could not load SSL private key from file "` ~ path ~ `"`
)
);
} else {
throw new TSSLException("Unsupported certificate format: " ~ format);
}
}
/**
* Load trusted certificates from specified file (in PEM format).
*
* Params.
* path = Path to the file containing the trusted certificates.
*/
void loadTrustedCertificates(string path) {
enforce(path !is null, new TTransportException(
"loadTrustedCertificates: <path> is NULL",
TTransportException.Type.BAD_ARGS));
enforce(SSL_CTX_load_verify_locations(ctx_, toStringz(path), null),
getSSLException(
`Could not load SSL trusted certificate list from file "` ~ path ~ `"`
)
);
}
/**
* Called during OpenSSL initialization to seed the OpenSSL entropy pool.
*
* Defaults to simply calling RAND_poll(), but it can be overwritten if a
* different, perhaps more secure implementation is desired.
*/
void randomize() {
RAND_poll();
}
/**
* Whether to use client or server side SSL handshake protocol.
*/
bool serverSide() @property const {
return serverSide_;
}
/// Ditto
void serverSide(bool value) @property {
serverSide_ = value;
}
/**
* The access manager to use.
*/
TAccessManager accessManager() @property {
if (!serverSide_ && !accessManager_) {
accessManager_ = new TDefaultClientAccessManager;
}
return accessManager_;
}
/// Ditto
void accessManager(TAccessManager value) @property {
accessManager_ = value;
}
SSL* createSSL() out (result) {
assert(result);
} body {
auto result = SSL_new(ctx_);
enforce(result, getSSLException("SSL_new"));
return result;
}
protected:
/**
* Override this method for custom password callback. It may be called
* multiple times at any time during a session as necessary.
*
* Params:
* size = Maximum length of password, including null byte.
*/
string getPassword(int size) nothrow out(result) {
assert(result.length < size);
} body {
return "";
}
/**
* Notifies OpenSSL to use getPassword() instead of the default password
* callback with getPassword().
*/
void overrideDefaultPasswordCallback() {
SSL_CTX_set_default_passwd_cb(ctx_, &passwordCallback);
SSL_CTX_set_default_passwd_cb_userdata(ctx_, cast(void*)this);
}
SSL_CTX* ctx_;
private:
bool serverSide_;
TAccessManager accessManager_;
shared static this() {
initMutex_ = new Mutex();
}
static void initializeOpenSSL() {
if (initialized_) {
return;
}
initialized_ = true;
SSL_library_init();
SSL_load_error_strings();
mutexes_ = new Mutex[CRYPTO_num_locks()];
foreach (ref m; mutexes_) {
m = new Mutex;
}
import thrift.internal.traits;
// As per the OpenSSL threads manpage, this isn't needed on Windows.
version (Posix) {
CRYPTO_set_id_callback(assumeNothrow(&threadIdCallback));
}
CRYPTO_set_locking_callback(assumeNothrow(&lockingCallback));
CRYPTO_set_dynlock_create_callback(assumeNothrow(&dynlockCreateCallback));
CRYPTO_set_dynlock_lock_callback(assumeNothrow(&dynlockLockCallback));
CRYPTO_set_dynlock_destroy_callback(assumeNothrow(&dynlockDestroyCallback));
}
static void cleanupOpenSSL() {
if (!initialized_) return;
initialized_ = false;
CRYPTO_set_locking_callback(null);
CRYPTO_set_dynlock_create_callback(null);
CRYPTO_set_dynlock_lock_callback(null);
CRYPTO_set_dynlock_destroy_callback(null);
CRYPTO_cleanup_all_ex_data();
ERR_free_strings();
ERR_remove_state(0);
}
static extern(C) {
version (Posix) {
import core.sys.posix.pthread : pthread_self;
c_ulong threadIdCallback() {
return cast(c_ulong)pthread_self();
}
}
void lockingCallback(int mode, int n, const(char)* file, int line) {
if (mode & CRYPTO_LOCK) {
mutexes_[n].lock();
} else {
mutexes_[n].unlock();
}
}
CRYPTO_dynlock_value* dynlockCreateCallback(const(char)* file, int line) {
enum size = __traits(classInstanceSize, Mutex);
auto mem = malloc(size)[0 .. size];
if (!mem) onOutOfMemoryError();
GC.addRange(mem.ptr, size);
auto mutex = emplace!Mutex(mem);
return cast(CRYPTO_dynlock_value*)mutex;
}
void dynlockLockCallback(int mode, CRYPTO_dynlock_value* l,
const(char)* file, int line)
{
if (l is null) return;
if (mode & CRYPTO_LOCK) {
(cast(Mutex)l).lock();
} else {
(cast(Mutex)l).unlock();
}
}
void dynlockDestroyCallback(CRYPTO_dynlock_value* l,
const(char)* file, int line)
{
GC.removeRange(l);
destroy(cast(Mutex)l);
free(l);
}
int passwordCallback(char* password, int size, int, void* data) nothrow {
auto context = cast(TSSLContext) data;
auto userPassword = context.getPassword(size);
auto len = userPassword.length;
if (len > size) {
len = size;
}
password[0 .. len] = userPassword[0 .. len]; // TODO: \0 handling correct?
return cast(int)len;
}
}
static __gshared bool initialized_;
static __gshared Mutex initMutex_;
static __gshared Mutex[] mutexes_;
static __gshared uint count_;
}
/**
* Decides whether a remote host is legitimate or not.
*
* It is usually set at a TSSLContext, which then passes it to all the created
* TSSLSockets.
*/
class TAccessManager {
///
enum Decision {
DENY = -1, /// Deny access.
SKIP = 0, /// Cannot decide, move on to next check (deny if last).
ALLOW = 1 /// Allow access.
}
/**
* Determines whether a peer should be granted access or not based on its
* IP address.
*
* Called once after SSL handshake is completes successfully and before peer
* certificate is examined.
*
* If a valid decision (ALLOW or DENY) is returned, the peer certificate
* will not be verified.
*/
Decision verify(Address address) {
return Decision.DENY;
}
/**
* Determines whether a peer should be granted access or not based on a
* name from its certificate.
*
* Called every time a DNS subjectAltName/common name is extracted from the
* peer's certificate.
*
* Params:
* host = The actual host name string from the socket connection.
* certHost = A host name string from the certificate.
*/
Decision verify(string host, const(char)[] certHost) {
return Decision.DENY;
}
/**
* Determines whether a peer should be granted access or not based on an IP
* address from its certificate.
*
* Called every time an IP subjectAltName is extracted from the peer's
* certificate.
*
* Params:
* address = The actual address from the socket connection.
* certHost = A host name string from the certificate.
*/
Decision verify(Address address, ubyte[] certAddress) {
return Decision.DENY;
}
}
/**
* Default access manager implementation, which just checks the host name
* resp. IP address of the connection against the certificate.
*/
class TDefaultClientAccessManager : TAccessManager {
override Decision verify(Address address) {
return Decision.SKIP;
}
override Decision verify(string host, const(char)[] certHost) {
if (host.empty || certHost.empty) {
return Decision.SKIP;
}
return (matchName(host, certHost) ? Decision.ALLOW : Decision.SKIP);
}
override Decision verify(Address address, ubyte[] certAddress) {
bool match;
if (certAddress.length == 4) {
if (auto ia = cast(InternetAddress)address) {
match = ((cast(ubyte*)ia.addr())[0 .. 4] == certAddress[]);
}
} else if (certAddress.length == 16) {
if (auto ia = cast(Internet6Address)address) {
match = (ia.addr() == certAddress[]);
}
}
return (match ? Decision.ALLOW : Decision.SKIP);
}
}
private {
/**
* Matches a name with a pattern. The pattern may include wildcard. A single
* wildcard "*" can match up to one component in the domain name.
*
* Params:
* host = Host name to match, typically the SSL remote peer.
* pattern = Host name pattern, typically from the SSL certificate.
*
* Returns: true if host matches pattern, false otherwise.
*/
bool matchName(const(char)[] host, const(char)[] pattern) {
while (!host.empty && !pattern.empty) {
if (toUpper(pattern.front) == toUpper(host.front)) {
host.popFront;
pattern.popFront;
} else if (pattern.front == '*') {
while (!host.empty && host.front != '.') {
host.popFront;
}
pattern.popFront;
} else {
break;
}
}
return (host.empty && pattern.empty);
}
unittest {
enforce(matchName("thrift.apache.org", "*.apache.org"));
enforce(!matchName("thrift.apache.org", "apache.org"));
enforce(matchName("thrift.apache.org", "thrift.*.*"));
enforce(matchName("", ""));
enforce(!matchName("", "*"));
}
}
/**
* SSL-level exception.
*/
class TSSLException : TTransportException {
///
this(string msg, string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, TTransportException.Type.INTERNAL_ERROR, file, line, next);
}
}

View file

@ -0,0 +1,497 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.transport.zlib;
import core.bitop : bswap;
import etc.c.zlib;
import std.algorithm : min;
import std.array : empty;
import std.conv : to;
import std.exception : enforce;
import thrift.base;
import thrift.transport.base;
/**
* zlib transport. Compresses (deflates) data before writing it to the
* underlying transport, and decompresses (inflates) it after reading.
*/
final class TZlibTransport : TBaseTransport {
// These defaults have yet to be optimized.
enum DEFAULT_URBUF_SIZE = 128;
enum DEFAULT_CRBUF_SIZE = 1024;
enum DEFAULT_UWBUF_SIZE = 128;
enum DEFAULT_CWBUF_SIZE = 1024;
/**
* Constructs a new zlib transport.
*
* Params:
* transport = The underlying transport to wrap.
* urbufSize = The size of the uncompressed reading buffer, in bytes.
* crbufSize = The size of the compressed reading buffer, in bytes.
* uwbufSize = The size of the uncompressed writing buffer, in bytes.
* cwbufSize = The size of the compressed writing buffer, in bytes.
*/
this(
TTransport transport,
size_t urbufSize = DEFAULT_URBUF_SIZE,
size_t crbufSize = DEFAULT_CRBUF_SIZE,
size_t uwbufSize = DEFAULT_UWBUF_SIZE,
size_t cwbufSize = DEFAULT_CWBUF_SIZE
) {
transport_ = transport;
enforce(uwbufSize >= MIN_DIRECT_DEFLATE_SIZE, new TTransportException(
"TZLibTransport: uncompressed write buffer must be at least " ~
to!string(MIN_DIRECT_DEFLATE_SIZE) ~ "bytes in size.",
TTransportException.Type.BAD_ARGS));
urbuf_ = new ubyte[urbufSize];
crbuf_ = new ubyte[crbufSize];
uwbuf_ = new ubyte[uwbufSize];
cwbuf_ = new ubyte[cwbufSize];
rstream_ = new z_stream;
rstream_.next_in = crbuf_.ptr;
rstream_.avail_in = 0;
rstream_.next_out = urbuf_.ptr;
rstream_.avail_out = to!uint(urbuf_.length);
wstream_ = new z_stream;
wstream_.next_in = uwbuf_.ptr;
wstream_.avail_in = 0;
wstream_.next_out = cwbuf_.ptr;
wstream_.avail_out = to!uint(crbuf_.length);
zlibEnforce(inflateInit(rstream_), rstream_);
scope (failure) {
zlibLogError(inflateEnd(rstream_), rstream_);
}
zlibEnforce(deflateInit(wstream_, Z_DEFAULT_COMPRESSION), wstream_);
}
~this() {
zlibLogError(inflateEnd(rstream_), rstream_);
auto result = deflateEnd(wstream_);
// Z_DATA_ERROR may indicate unflushed data, so just ignore it.
if (result != Z_DATA_ERROR) {
zlibLogError(result, wstream_);
}
}
/**
* Returns the wrapped transport.
*/
TTransport underlyingTransport() @property {
return transport_;
}
override bool isOpen() @property {
return readAvail > 0 || transport_.isOpen;
}
override bool peek() {
return readAvail > 0 || transport_.peek();
}
override void open() {
transport_.open();
}
override void close() {
transport_.close();
}
override size_t read(ubyte[] buf) {
// The C++ implementation suggests to skip urbuf on big reads in future
// versions, we would benefit from it as well.
auto origLen = buf.length;
while (true) {
auto give = min(readAvail, buf.length);
// If std.range.put was optimized for slicable ranges, it could be used
// here as well.
buf[0 .. give] = urbuf_[urpos_ .. urpos_ + give];
buf = buf[give .. $];
urpos_ += give;
auto need = buf.length;
if (need == 0) {
// We could manage to get the all the data requested.
return origLen;
}
if (inputEnded_ || (need < origLen && rstream_.avail_in == 0)) {
// We didn't fill buf completely, but there is no more data available.
return origLen - need;
}
// Refill our buffer by reading more data through zlib.
rstream_.next_out = urbuf_.ptr;
rstream_.avail_out = to!uint(urbuf_.length);
urpos_ = 0;
if (!readFromZlib()) {
// Couldn't get more data from the underlying transport.
return origLen - need;
}
}
}
override void write(in ubyte[] buf) {
enforce(!outputFinished_, new TTransportException(
"write() called after finish()", TTransportException.Type.BAD_ARGS));
auto len = buf.length;
if (len > MIN_DIRECT_DEFLATE_SIZE) {
flushToZlib(uwbuf_[0 .. uwpos_], Z_NO_FLUSH);
uwpos_ = 0;
flushToZlib(buf, Z_NO_FLUSH);
} else if (len > 0) {
if (uwbuf_.length - uwpos_ < len) {
flushToZlib(uwbuf_[0 .. uwpos_], Z_NO_FLUSH);
uwpos_ = 0;
}
uwbuf_[uwpos_ .. uwpos_ + len] = buf[];
uwpos_ += len;
}
}
override void flush() {
enforce(!outputFinished_, new TTransportException(
"flush() called after finish()", TTransportException.Type.BAD_ARGS));
flushToTransport(Z_SYNC_FLUSH);
}
override const(ubyte)[] borrow(ubyte* buf, size_t len) {
if (len <= readAvail) {
return urbuf_[urpos_ .. $];
}
return null;
}
override void consume(size_t len) {
enforce(readAvail >= len, new TTransportException(
"consume() did not follow a borrow().", TTransportException.Type.BAD_ARGS));
urpos_ += len;
}
/**
* Finalize the zlib stream.
*
* This causes zlib to flush any pending write data and write end-of-stream
* information, including the checksum. Once finish() has been called, no
* new data can be written to the stream.
*/
void finish() {
enforce(!outputFinished_, new TTransportException(
"flush() called on already finished TZlibTransport",
TTransportException.Type.BAD_ARGS));
flushToTransport(Z_FINISH);
}
/**
* Verify the checksum at the end of the zlib stream (by finish()).
*
* May only be called after all data has been read.
*
* Throws: TTransportException when the checksum is corrupted or there is
* still unread data left.
*/
void verifyChecksum() {
// If zlib has already reported the end of the stream, the checksum has
// been verified, no.
if (inputEnded_) return;
enforce(!readAvail, new TTransportException(
"verifyChecksum() called before end of zlib stream",
TTransportException.Type.CORRUPTED_DATA));
rstream_.next_out = urbuf_.ptr;
rstream_.avail_out = to!uint(urbuf_.length);
urpos_ = 0;
// readFromZlib() will throw an exception if the checksum is bad.
enforce(readFromZlib(), new TTransportException(
"checksum not available yet in verifyChecksum()",
TTransportException.Type.CORRUPTED_DATA));
enforce(inputEnded_, new TTransportException(
"verifyChecksum() called before end of zlib stream",
TTransportException.Type.CORRUPTED_DATA));
// If we get here, we are at the end of the stream and thus zlib has
// successfully verified the checksum.
}
private:
size_t readAvail() const @property {
return urbuf_.length - rstream_.avail_out - urpos_;
}
bool readFromZlib() {
assert(!inputEnded_);
if (rstream_.avail_in == 0) {
// zlib has used up all the compressed data we provided in crbuf, read
// some more from the underlying transport.
auto got = transport_.read(crbuf_);
if (got == 0) return false;
rstream_.next_in = crbuf_.ptr;
rstream_.avail_in = to!uint(got);
}
// We have some compressed data now, uncompress it.
auto zlib_result = inflate(rstream_, Z_SYNC_FLUSH);
if (zlib_result == Z_STREAM_END) {
inputEnded_ = true;
} else {
zlibEnforce(zlib_result, rstream_);
}
return true;
}
void flushToTransport(int type) {
// Compress remaining data in uwbuf_ to cwbuf_.
flushToZlib(uwbuf_[0 .. uwpos_], type);
uwpos_ = 0;
// Write all compressed data to the transport.
transport_.write(cwbuf_[0 .. $ - wstream_.avail_out]);
wstream_.next_out = cwbuf_.ptr;
wstream_.avail_out = to!uint(cwbuf_.length);
// Flush the transport.
transport_.flush();
}
void flushToZlib(in ubyte[] buf, int type) {
wstream_.next_in = cast(ubyte*)buf.ptr; // zlib only reads, cast is safe.
wstream_.avail_in = to!uint(buf.length);
while (true) {
if (type == Z_NO_FLUSH && wstream_.avail_in == 0) {
break;
}
if (wstream_.avail_out == 0) {
// cwbuf has been exhausted by zlib, flush to the underlying transport.
transport_.write(cwbuf_);
wstream_.next_out = cwbuf_.ptr;
wstream_.avail_out = to!uint(cwbuf_.length);
}
auto zlib_result = deflate(wstream_, type);
if (type == Z_FINISH && zlib_result == Z_STREAM_END) {
assert(wstream_.avail_in == 0);
outputFinished_ = true;
break;
}
zlibEnforce(zlib_result, wstream_);
if ((type == Z_SYNC_FLUSH || type == Z_FULL_FLUSH) &&
wstream_.avail_in == 0 && wstream_.avail_out != 0) {
break;
}
}
}
static void zlibEnforce(int status, z_stream* stream) {
if (status != Z_OK) {
throw new TZlibException(status, stream.msg);
}
}
static void zlibLogError(int status, z_stream* stream) {
if (status != Z_OK) {
logError("TZlibTransport: zlib failure in destructor: %s",
TZlibException.errorMessage(status, stream.msg));
}
}
// Writes smaller than this are buffered up (due to zlib handling overhead).
// Larger (or equal) writes are dumped straight to zlib.
enum MIN_DIRECT_DEFLATE_SIZE = 32;
TTransport transport_;
z_stream* rstream_;
z_stream* wstream_;
/// Whether zlib has reached the end of the input stream.
bool inputEnded_;
/// Whether the output stream was already finish()ed.
bool outputFinished_;
/// Compressed input data buffer.
ubyte[] crbuf_;
/// Uncompressed input data buffer.
ubyte[] urbuf_;
size_t urpos_;
/// Uncompressed output data buffer (where small writes are accumulated
/// before handing over to zlib).
ubyte[] uwbuf_;
size_t uwpos_;
/// Compressed output data buffer (filled by zlib, we flush it to the
/// underlying transport).
ubyte[] cwbuf_;
}
/**
* Wraps given transports into TZlibTransports.
*/
alias TWrapperTransportFactory!TZlibTransport TZlibTransportFactory;
/**
* An INTERNAL_ERROR-type TTransportException originating from an error
* signaled by zlib.
*/
class TZlibException : TTransportException {
this(int statusCode, const(char)* msg) {
super(errorMessage(statusCode, msg), TTransportException.Type.INTERNAL_ERROR);
zlibStatusCode = statusCode;
zlibMsg = msg ? to!string(msg) : "(null)";
}
int zlibStatusCode;
string zlibMsg;
static string errorMessage(int statusCode, const(char)* msg) {
string result = "zlib error: ";
if (msg) {
result ~= to!string(msg);
} else {
result ~= "(no message)";
}
result ~= " (status code = " ~ to!string(statusCode) ~ ")";
return result;
}
}
version (unittest) {
import std.exception : collectException;
import thrift.transport.memory;
}
// Make sure basic reading/writing works.
unittest {
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
immutable ubyte[] data = [1, 2, 3, 4, 5];
zlib.write(data);
zlib.finish();
auto result = new ubyte[data.length];
zlib.readAll(result);
enforce(data == result);
zlib.verifyChecksum();
}
// Make sure there is no data is written if write() is never called.
unittest {
auto buf = new TMemoryBuffer;
{
scope zlib = new TZlibTransport(buf);
}
enforce(buf.getContents().length == 0);
}
// Make sure calling write()/flush()/finish() again after finish() throws.
unittest {
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
zlib.write([1, 2, 3, 4, 5]);
zlib.finish();
auto ex = collectException!TTransportException(zlib.write([6]));
enforce(ex && ex.type == TTransportException.Type.BAD_ARGS);
ex = collectException!TTransportException(zlib.flush());
enforce(ex && ex.type == TTransportException.Type.BAD_ARGS);
ex = collectException!TTransportException(zlib.finish());
enforce(ex && ex.type == TTransportException.Type.BAD_ARGS);
}
// Make sure verifying the checksum works even if it requires starting a new
// reading buffer after reading the payload has already been completed.
unittest {
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
immutable ubyte[] data = [1, 2, 3, 4, 5];
zlib.write(data);
zlib.finish();
zlib = new TZlibTransport(buf, TZlibTransport.DEFAULT_URBUF_SIZE,
buf.getContents().length - 1); // The last byte belongs to the checksum.
auto result = new ubyte[data.length];
zlib.readAll(result);
enforce(data == result);
zlib.verifyChecksum();
}
// Make sure verifyChecksum() throws if we messed with the checksum.
unittest {
import std.stdio;
import thrift.transport.range;
auto buf = new TMemoryBuffer;
auto zlib = new TZlibTransport(buf);
immutable ubyte[] data = [1, 2, 3, 4, 5];
zlib.write(data);
zlib.finish();
void testCorrupted(const(ubyte)[] corruptedData) {
auto reader = new TZlibTransport(tInputRangeTransport(corruptedData));
auto result = new ubyte[data.length];
try {
reader.readAll(result);
// If it does read without complaining, the result should be correct.
enforce(result == data);
} catch (TZlibException e) {}
auto ex = collectException!TTransportException(reader.verifyChecksum());
enforce(ex && ex.type == TTransportException.Type.CORRUPTED_DATA);
}
testCorrupted(buf.getContents()[0 .. $ - 1]);
auto modified = buf.getContents().dup;
++modified[$ - 1];
testCorrupted(modified);
}

View file

@ -0,0 +1,212 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.awaitable;
import core.sync.condition;
import core.sync.mutex;
import core.time : Duration;
import std.exception : enforce;
import std.socket/+ : Socket, socketPair+/; // DMD @@BUG314@@
import thrift.base;
// To avoid DMD @@BUG6395@@.
import thrift.internal.algorithm;
/**
* An event that can occur at some point in the future and which can be
* awaited, either by blocking until it occurs, or by registering a callback
* delegate.
*/
interface TAwaitable {
/**
* Waits until the event occurs.
*
* Calling wait() for an event that has already occurred is a no-op.
*/
void wait();
/**
* Waits until the event occurs or the specified timeout expires.
*
* Calling wait() for an event that has already occurred is a no-op.
*
* Returns: Whether the event was triggered before the timeout expired.
*/
bool wait(Duration timeout);
/**
* Registers a callback that is called if the event occurs.
*
* The delegate will likely be invoked from a different thread, and is
* expected not to perform expensive work as it will usually be invoked
* synchronously by the notifying thread. The order in which registered
* callbacks are invoked is not specified.
*
* The callback must never throw, but nothrow semantics are difficult to
* enforce, so currently exceptions are just swallowed by
* TAwaitable implementations.
*
* If the event has already occurred, the delegate is immediately executed
* in the current thread.
*/
void addCallback(void delegate() dg);
/**
* Removes a previously added callback.
*
* Returns: Whether the callback could be found in the list, i.e. whether it
* was previously added.
*/
bool removeCallback(void delegate() dg);
}
/**
* A simple TAwaitable event triggered by just calling a trigger() method.
*/
class TOneshotEvent : TAwaitable {
this() {
mutex_ = new Mutex;
condition_ = new Condition(mutex_);
}
override void wait() {
synchronized (mutex_) {
while (!triggered_) condition_.wait();
}
}
override bool wait(Duration timeout) {
synchronized (mutex_) {
if (triggered_) return true;
condition_.wait(timeout);
return triggered_;
}
}
override void addCallback(void delegate() dg) {
mutex_.lock();
scope (failure) mutex_.unlock();
callbacks_ ~= dg;
if (triggered_) {
mutex_.unlock();
dg();
return;
}
mutex_.unlock();
}
override bool removeCallback(void delegate() dg) {
synchronized (mutex_) {
auto oldLength = callbacks_.length;
callbacks_ = removeEqual(callbacks_, dg);
return callbacks_.length < oldLength;
}
}
/**
* Triggers the event.
*
* Any registered event callbacks are executed synchronously before the
* function returns.
*/
void trigger() {
synchronized (mutex_) {
if (!triggered_) {
triggered_ = true;
condition_.notifyAll();
foreach (c; callbacks_) c();
}
}
}
private:
bool triggered_;
Mutex mutex_;
Condition condition_;
void delegate()[] callbacks_;
}
/**
* Translates TAwaitable events into dummy messages on a socket that can be
* used e.g. to wake up from a select() call.
*/
final class TSocketNotifier {
this() {
auto socks = socketPair();
foreach (s; socks) s.blocking = false;
sendSocket_ = socks[0];
recvSocket_ = socks[1];
}
/**
* The socket the messages will be sent to.
*/
Socket socket() @property {
return recvSocket_;
}
/**
* Atatches the socket notifier to the specified awaitable, causing it to
* write a byte to the notification socket when the awaitable callbacks are
* invoked.
*
* If the event has already been triggered, the dummy byte is written
* immediately to the socket.
*
* A socket notifier can only be attached to a single awaitable at a time.
*
* Throws: TException if the socket notifier is already attached.
*/
void attach(TAwaitable awaitable) {
enforce(!awaitable_, new TException("Already attached."));
awaitable.addCallback(&notify);
awaitable_ = awaitable;
}
/**
* Detaches the socket notifier from the awaitable it is currently attached
* to.
*
* Throws: TException if the socket notifier is not currently attached.
*/
void detach() {
enforce(awaitable_, new TException("Not attached."));
// Soak up any not currently read notification bytes.
ubyte[1] dummy = void;
while (recvSocket_.receive(dummy) != Socket.ERROR) {}
auto couldRemove = awaitable_.removeCallback(&notify);
assert(couldRemove);
awaitable_ = null;
}
private:
void notify() {
ubyte[1] zero;
sendSocket_.send(zero);
}
TAwaitable awaitable_;
Socket sendSocket_;
Socket recvSocket_;
}

View file

@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.cancellation;
import core.atomic;
import thrift.base;
import thrift.util.awaitable;
/**
* A cancellation request for asynchronous or blocking synchronous operations.
*
* It is passed to the entity creating an operation, which will usually monitor
* it either by polling or by adding event handlers, and cancel the operation
* if it is triggered.
*
* For synchronous operations, this usually means either throwing a
* TCancelledException or immediately returning, depending on whether
* cancellation is an expected part of the task outcome or not. For
* asynchronous operations, cancellation typically entails stopping background
* work and cancelling a result future, if not already completed.
*
* An operation accepting a TCancellation does not need to guarantee that it
* will actually be able to react to the cancellation request.
*/
interface TCancellation {
/**
* Whether the cancellation request has been triggered.
*/
bool triggered() const @property;
/**
* Throws a TCancelledException if the cancellation request has already been
* triggered.
*/
void throwIfTriggered() const;
/**
* A TAwaitable that can be used to wait for cancellation triggering.
*/
TAwaitable triggering() @property;
}
/**
* The origin of a cancellation request, which provides a way to actually
* trigger it.
*
* This design allows operations to pass the TCancellation on to sub-tasks,
* while making sure that the cancellation can only be triggered by the
* »outermost« instance waiting for the result.
*/
final class TCancellationOrigin : TCancellation {
this() {
event_ = new TOneshotEvent;
}
/**
* Triggers the cancellation request.
*/
void trigger() {
atomicStore(triggered_, true);
event_.trigger();
}
/+override+/ bool triggered() const @property {
return atomicLoad(triggered_);
}
/+override+/ void throwIfTriggered() const {
if (triggered) throw new TCancelledException;
}
/+override+/ TAwaitable triggering() @property {
return event_;
}
private:
shared bool triggered_;
TOneshotEvent event_;
}
///
class TCancelledException : TException {
///
this(string msg = null, string file = __FILE__, size_t line = __LINE__,
Throwable next = null
) {
super(msg ? msg : "The operation has been cancelled.", file, line, next);
}
}

View file

@ -0,0 +1,549 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.future;
import core.atomic;
import core.sync.condition;
import core.sync.mutex;
import core.time : Duration;
import std.array : empty, front, popFront;
import std.conv : to;
import std.exception : enforce;
import std.traits : BaseTypeTuple, isSomeFunction, ParameterTypeTuple, ReturnType;
import thrift.base;
import thrift.util.awaitable;
import thrift.util.cancellation;
/**
* Represents an operation which is executed asynchronously and the result of
* which will become available at some point in the future.
*
* Once a operation is completed, the result of the operation can be fetched
* via the get() family of methods. There are three possible cases: Either the
* operation succeeded, then its return value is returned, or it failed by
* throwing, in which case the exception is rethrown, or it was cancelled
* before, then a TCancelledException is thrown. There might be TFuture
* implementations which never possibly enter the cancelled state.
*
* All methods are thread-safe, but keep in mind that any exception object or
* result (if it is a reference type, of course) is shared between all
* get()-family invocations.
*/
interface TFuture(ResultType) {
/**
* The status the operation is currently in.
*
* An operation starts out in RUNNING status, and changes state to one of the
* others at most once afterwards.
*/
TFutureStatus status() @property;
/**
* A TAwaitable triggered when the operation leaves the RUNNING status.
*/
TAwaitable completion() @property;
/**
* Convenience shorthand for waiting until the result is available and then
* get()ing it.
*
* If the operation has already completed, the result is immediately
* returned.
*
* The result of this method is »alias this«'d to the interface, so that
* TFuture can be used as a drop-in replacement for a simple value in
* synchronous code.
*/
final ResultType waitGet() {
completion.wait();
return get();
}
final @property auto waitGetProperty() { return waitGet(); }
alias waitGetProperty this;
/**
* Convenience shorthand for waiting until the result is available and then
* get()ing it.
*
* If the operation completes in time, returns its result (resp. throws an
* exception for the failed/cancelled cases). If not, throws a
* TFutureException.
*/
final ResultType waitGet(Duration timeout) {
enforce(completion.wait(timeout), new TFutureException(
"Operation did not complete in time."));
return get();
}
/**
* Returns the result of the operation.
*
* Throws: TFutureException if the operation has been cancelled,
* TCancelledException if it is not yet done; the set exception if it
* failed.
*/
ResultType get();
/**
* Returns the captured exception if the operation failed, or null otherwise.
*
* Throws: TFutureException if not yet done, TCancelledException if the
* operation has been cancelled.
*/
Exception getException();
}
/**
* The states the operation offering a future interface can be in.
*/
enum TFutureStatus : byte {
RUNNING, /// The operation is still running.
SUCCEEDED, /// The operation completed without throwing an exception.
FAILED, /// The operation completed by throwing an exception.
CANCELLED /// The operation was cancelled.
}
/**
* A TFuture covering the simple but common case where the result is simply
* set by a call to succeed()/fail().
*
* All methods are thread-safe, but usually, succeed()/fail() are only called
* from a single thread (different from the thread(s) waiting for the result
* using the TFuture interface, though).
*/
class TPromise(ResultType) : TFuture!ResultType {
this() {
statusMutex_ = new Mutex;
completionEvent_ = new TOneshotEvent;
}
override S status() const @property {
return atomicLoad(status_);
}
override TAwaitable completion() @property {
return completionEvent_;
}
override ResultType get() {
auto s = atomicLoad(status_);
enforce(s != S.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == S.CANCELLED) throw new TCancelledException;
if (s == S.FAILED) throw exception_;
static if (!is(ResultType == void)) {
return result_;
}
}
override Exception getException() {
auto s = atomicLoad(status_);
enforce(s != S.RUNNING,
new TFutureException("Operation not yet completed."));
if (s == S.CANCELLED) throw new TCancelledException;
if (s == S.SUCCEEDED) return null;
return exception_;
}
static if (!is(ResultType == void)) {
/**
* Sets the result of the operation, marks it as done, and notifies any
* waiters.
*
* If the operation has been cancelled before, nothing happens.
*
* Throws: TFutureException if the operation is already completed.
*/
void succeed(ResultType result) {
synchronized (statusMutex_) {
auto s = atomicLoad(status_);
if (s == S.CANCELLED) return;
enforce(s == S.RUNNING,
new TFutureException("Operation already completed."));
result_ = result;
atomicStore(status_, S.SUCCEEDED);
}
completionEvent_.trigger();
}
} else {
void succeed() {
synchronized (statusMutex_) {
auto s = atomicLoad(status_);
if (s == S.CANCELLED) return;
enforce(s == S.RUNNING,
new TFutureException("Operation already completed."));
atomicStore(status_, S.SUCCEEDED);
}
completionEvent_.trigger();
}
}
/**
* Marks the operation as failed with the specified exception and notifies
* any waiters.
*
* If the operation was already cancelled, nothing happens.
*
* Throws: TFutureException if the operation is already completed.
*/
void fail(Exception exception) {
synchronized (statusMutex_) {
auto status = atomicLoad(status_);
if (status == S.CANCELLED) return;
enforce(status == S.RUNNING,
new TFutureException("Operation already completed."));
exception_ = exception;
atomicStore(status_, S.FAILED);
}
completionEvent_.trigger();
}
/**
* Marks this operation as completed and takes over the outcome of another
* TFuture of the same type.
*
* If this operation was already cancelled, nothing happens. If the other
* operation was cancelled, this operation is marked as failed with a
* TCancelledException.
*
* Throws: TFutureException if the passed in future was not completed or
* this operation is already completed.
*/
void complete(TFuture!ResultType future) {
synchronized (statusMutex_) {
auto status = atomicLoad(status_);
if (status == S.CANCELLED) return;
enforce(status == S.RUNNING,
new TFutureException("Operation already completed."));
enforce(future.status != S.RUNNING, new TFutureException(
"The passed TFuture is not yet completed."));
status = future.status;
if (status == S.CANCELLED) {
status = S.FAILED;
exception_ = new TCancelledException;
} else if (status == S.FAILED) {
exception_ = future.getException();
} else static if (!is(ResultType == void)) {
result_ = future.get();
}
atomicStore(status_, status);
}
completionEvent_.trigger();
}
/**
* Marks this operation as cancelled and notifies any waiters.
*
* If the operation is already completed, nothing happens.
*/
void cancel() {
synchronized (statusMutex_) {
auto status = atomicLoad(status_);
if (status == S.RUNNING) atomicStore(status_, S.CANCELLED);
}
completionEvent_.trigger();
}
private:
// Convenience alias because TFutureStatus is ubiquitous in this class.
alias TFutureStatus S;
// The status the promise is currently in.
shared S status_;
union {
static if (!is(ResultType == void)) {
// Set if status_ is SUCCEEDED.
ResultType result_;
}
// Set if status_ is FAILED.
Exception exception_;
}
// Protects status_.
// As for result_ and exception_: They are only set once, while status_ is
// still RUNNING, so given that the operation has already completed, reading
// them is safe without holding some kind of lock.
Mutex statusMutex_;
// Triggered when the event completes.
TOneshotEvent completionEvent_;
}
///
class TFutureException : TException {
///
this(string msg = "", string file = __FILE__, size_t line = __LINE__,
Throwable next = null)
{
super(msg, file, line, next);
}
}
/**
* Creates an interface that is similar to a given one, but accepts an
* additional, optional TCancellation parameter each method, and returns
* TFutures instead of plain return values.
*
* For example, given the following declarations:
* ---
* interface Foo {
* void bar();
* string baz(int a);
* }
* alias TFutureInterface!Foo FutureFoo;
* ---
*
* FutureFoo would be equivalent to:
* ---
* interface FutureFoo {
* TFuture!void bar(TCancellation cancellation = null);
* TFuture!string baz(int a, TCancellation cancellation = null);
* }
* ---
*/
template TFutureInterface(Interface) if (is(Interface _ == interface)) {
mixin({
string code = "interface TFutureInterface \n";
static if (is(Interface Bases == super) && Bases.length > 0) {
code ~= ": ";
foreach (i; 0 .. Bases.length) {
if (i > 0) code ~= ", ";
code ~= "TFutureInterface!(BaseTypeTuple!Interface[" ~ to!string(i) ~ "]) ";
}
}
code ~= "{\n";
foreach (methodName; __traits(derivedMembers, Interface)) {
enum qn = "Interface." ~ methodName;
static if (isSomeFunction!(mixin(qn))) {
code ~= "TFuture!(ReturnType!(" ~ qn ~ ")) " ~ methodName ~
"(ParameterTypeTuple!(" ~ qn ~ "), TCancellation cancellation = null);\n";
}
}
code ~= "}\n";
return code;
}());
}
/**
* An input range that aggregates results from multiple asynchronous operations,
* returning them in the order they arrive.
*
* Additionally, a timeout can be set after which results from not yet finished
* futures will no longer be waited for, e.g. to ensure the time it takes to
* iterate over a set of results is limited.
*/
final class TFutureAggregatorRange(T) {
/**
* Constructs a new instance.
*
* Params:
* futures = The set of futures to collect results from.
* timeout = If positive, not yet finished futures will be cancelled and
* their results will not be taken into account.
*/
this(TFuture!T[] futures, TCancellationOrigin childCancellation,
Duration timeout = dur!"hnsecs"(0)
) {
if (timeout > dur!"hnsecs"(0)) {
timeoutSysTick_ = TickDuration.currSystemTick +
TickDuration.from!"hnsecs"(timeout.total!"hnsecs");
} else {
timeoutSysTick_ = TickDuration(0);
}
queueMutex_ = new Mutex;
queueNonEmptyCondition_ = new Condition(queueMutex_);
futures_ = futures;
childCancellation_ = childCancellation;
foreach (future; futures_) {
future.completion.addCallback({
auto f = future;
return {
if (f.status == TFutureStatus.CANCELLED) return;
assert(f.status != TFutureStatus.RUNNING);
synchronized (queueMutex_) {
completedQueue_ ~= f;
if (completedQueue_.length == 1) {
queueNonEmptyCondition_.notifyAll();
}
}
};
}());
}
}
/**
* Whether the range is empty.
*
* This is the case if the results from the completed futures not having
* failed have already been popped and either all future have been finished
* or the timeout has expired.
*
* Potentially blocks until a new result is available or the timeout has
* expired.
*/
bool empty() @property {
if (finished_) return true;
if (bufferFilled_) return false;
while (true) {
TFuture!T future;
synchronized (queueMutex_) {
// The while loop is just being cautious about spurious wakeups, in
// case they should be possible.
while (completedQueue_.empty) {
auto remaining = to!Duration(timeoutSysTick_ -
TickDuration.currSystemTick);
if (remaining <= dur!"hnsecs"(0)) {
// No time left, but still no element received we are empty now.
finished_ = true;
childCancellation_.trigger();
return true;
}
queueNonEmptyCondition_.wait(remaining);
}
future = completedQueue_.front;
completedQueue_.popFront();
}
++completedCount_;
if (completedCount_ == futures_.length) {
// This was the last future in the list, there is no possibility
// another result could ever become available.
finished_ = true;
}
if (future.status == TFutureStatus.FAILED) {
// This one failed, loop again and try getting another item from
// the queue.
exceptions_ ~= future.getException();
} else {
resultBuffer_ = future.get();
bufferFilled_ = true;
return false;
}
}
}
/**
* Returns the first element from the range.
*
* Potentially blocks until a new result is available or the timeout has
* expired.
*
* Throws: TException if the range is empty.
*/
T front() {
enforce(!empty, new TException(
"Cannot get front of an empty future aggregator range."));
return resultBuffer_;
}
/**
* Removes the first element from the range.
*
* Potentially blocks until a new result is available or the timeout has
* expired.
*
* Throws: TException if the range is empty.
*/
void popFront() {
enforce(!empty, new TException(
"Cannot pop front of an empty future aggregator range."));
bufferFilled_ = false;
}
/**
* The number of futures the result of which has been returned or which have
* failed so far.
*/
size_t completedCount() @property const {
return completedCount_;
}
/**
* The exceptions collected from failed TFutures so far.
*/
Exception[] exceptions() @property {
return exceptions_;
}
private:
TFuture!T[] futures_;
TCancellationOrigin childCancellation_;
// The system tick this operation will time out, or zero if no timeout has
// been set.
TickDuration timeoutSysTick_;
bool finished_;
bool bufferFilled_;
T resultBuffer_;
Exception[] exceptions_;
size_t completedCount_;
// The queue of completed futures. This (and the associated condition) are
// the only parts of this class that are accessed by multiple threads.
TFuture!T[] completedQueue_;
Mutex queueMutex_;
Condition queueNonEmptyCondition_;
}
/**
* TFutureAggregatorRange construction helper to avoid having to explicitly
* specify the value type, i.e. to allow the constructor being called using IFTI
* (see $(DMDBUG 6082, D Bugzilla enhancement requet 6082)).
*/
TFutureAggregatorRange!T tFutureAggregatorRange(T)(TFuture!T[] futures,
TCancellationOrigin childCancellation, Duration timeout = dur!"hnsecs"(0)
) {
return new TFutureAggregatorRange!T(futures, childCancellation, timeout);
}

View file

@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
module thrift.util.hashset;
import std.algorithm : joiner, map;
import std.conv : to;
import std.traits : isImplicitlyConvertible, ParameterTypeTuple;
import std.range : ElementType, isInputRange;
struct Void {}
/**
* A quickly hacked together hash set implementation backed by built-in
* associative arrays to have something to compile Thrift's set<> to until
* std.container gains something suitable.
*/
// Note: The funky pointer casts (i.e. *(cast(immutable(E)*)&e) instead of
// just cast(immutable(E))e) are a workaround for LDC 2 compatibility.
final class HashSet(E) {
///
this() {}
///
this(E[] elems...) {
insert(elems);
}
///
void insert(Stuff)(Stuff stuff) if (isImplicitlyConvertible!(Stuff, E)) {
aa_[*(cast(immutable(E)*)&stuff)] = Void.init;
}
///
void insert(Stuff)(Stuff stuff) if (
isInputRange!Stuff && isImplicitlyConvertible!(ElementType!Stuff, E)
) {
foreach (e; stuff) {
aa_[*(cast(immutable(E)*)&e)] = Void.init;
}
}
///
void opOpAssign(string op : "~", Stuff)(Stuff stuff) {
insert(stuff);
}
///
void remove(E e) {
aa_.remove(*(cast(immutable(E)*)&e));
}
alias remove removeKey;
///
void removeAll() {
aa_ = null;
}
///
size_t length() @property const {
return aa_.length;
}
///
size_t empty() @property const {
return !aa_.length;
}
///
bool opBinaryRight(string op : "in")(E e) const {
return (e in aa_) !is null;
}
///
auto opSlice() const {
// TODO: Implement using AA key range once available in release DMD/druntime
// to avoid allocation.
return cast(E[])(aa_.keys);
}
///
override string toString() const {
// Only provide toString() if to!string() is available for E (exceptions are
// e.g. delegates).
static if (is(typeof(to!string(E.init)) : string)) {
return "{" ~ to!string(joiner(map!`to!string(a)`(aa_.keys), ", ")) ~ "}";
} else {
// Cast to work around Object not being const-correct.
return (cast()super).toString();
}
}
///
override bool opEquals(Object other) const {
auto rhs = cast(const(HashSet))other;
if (rhs) {
return aa_ == rhs.aa_;
}
// Cast to work around Object not being const-correct.
return (cast()super).opEquals(other);
}
private:
Void[immutable(E)] aa_;
}
/// Ditto
auto hashSet(E)(E[] elems...) {
return new HashSet!E(elems);
}
unittest {
import std.exception;
auto a = hashSet(1, 2, 2, 3);
enforce(a.length == 3);
enforce(2 in a);
enforce(5 !in a);
enforce(a.toString().length == 9);
a.remove(2);
enforce(a.length == 2);
enforce(2 !in a);
a.removeAll();
enforce(a.empty);
enforce(a.toString() == "{}");
void delegate() dg;
auto b = hashSet(dg);
static assert(__traits(compiles, b.toString()));
}