Upgrading dependency to Thrift 0.12.0

This commit is contained in:
Renan DelValle 2018-11-27 18:03:50 -08:00
parent 3e4590dcc0
commit 356978cb42
No known key found for this signature in database
GPG key ID: C240AD6D6F443EC9
1302 changed files with 101701 additions and 26784 deletions

45
vendor/git.apache.org/thrift.git/lib/rs/src/autogen.rs generated vendored Normal file
View file

@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Thrift compiler auto-generated support.
//!
//!
//! Types and functions used internally by the Thrift compiler's Rust plugin
//! to implement required functionality. Users should never have to use code
//! in this module directly.
use protocol::{TInputProtocol, TOutputProtocol};
/// Specifies the minimum functionality an auto-generated client should provide
/// to communicate with a Thrift server.
pub trait TThriftClient {
/// Returns the input protocol used to read serialized Thrift messages
/// from the Thrift server.
fn i_prot_mut(&mut self) -> &mut TInputProtocol;
/// Returns the output protocol used to write serialized Thrift messages
/// to the Thrift server.
fn o_prot_mut(&mut self) -> &mut TOutputProtocol;
/// Returns the sequence number of the last message written to the Thrift
/// server. Returns `0` if no messages have been written. Sequence
/// numbers should *never* be negative, and this method returns an `i32`
/// simply because the Thrift protocol encodes sequence numbers as `i32` on
/// the wire.
fn sequence_number(&self) -> i32; // FIXME: consider returning a u32
/// Increments the sequence number, indicating that a message with that
/// number has been sent to the Thrift server.
fn increment_sequence_number(&mut self) -> i32;
}

712
vendor/git.apache.org/thrift.git/lib/rs/src/errors.rs generated vendored Normal file
View file

@ -0,0 +1,712 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::convert::{From, Into};
use std::error::Error as StdError;
use std::fmt::{Debug, Display, Formatter};
use std::{error, fmt, io, string};
use try_from::TryFrom;
use protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType};
// FIXME: should all my error structs impl error::Error as well?
// FIXME: should all fields in TransportError, ProtocolError and ApplicationError be optional?
/// Error type returned by all runtime library functions.
///
/// `thrift::Error` is used throughout this crate as well as in auto-generated
/// Rust code. It consists of four variants defined by convention across Thrift
/// implementations:
///
/// 1. `Transport`: errors encountered while operating on I/O channels
/// 2. `Protocol`: errors encountered during runtime-library processing
/// 3. `Application`: errors encountered within auto-generated code
/// 4. `User`: IDL-defined exception structs
///
/// The `Application` variant also functions as a catch-all: all handler errors
/// are automatically turned into application errors.
///
/// All error variants except `Error::User` take an eponymous struct with two
/// required fields:
///
/// 1. `kind`: variant-specific enum identifying the error sub-type
/// 2. `message`: human-readable error info string
///
/// `kind` is defined by convention while `message` is freeform. If none of the
/// enumerated kinds are suitable use `Unknown`.
///
/// To simplify error creation convenience constructors are defined for all
/// variants, and conversions from their structs (`thrift::TransportError`,
/// `thrift::ProtocolError` and `thrift::ApplicationError` into `thrift::Error`.
///
/// # Examples
///
/// Create a `TransportError`.
///
/// ```
/// use thrift;
/// use thrift::{TransportError, TransportErrorKind};
///
/// // explicit
/// let err0: thrift::Result<()> = Err(
/// thrift::Error::Transport(
/// TransportError {
/// kind: TransportErrorKind::TimedOut,
/// message: format!("connection to server timed out")
/// }
/// )
/// );
///
/// // use conversion
/// let err1: thrift::Result<()> = Err(
/// thrift::Error::from(
/// TransportError {
/// kind: TransportErrorKind::TimedOut,
/// message: format!("connection to server timed out")
/// }
/// )
/// );
///
/// // use struct constructor
/// let err2: thrift::Result<()> = Err(
/// thrift::Error::Transport(
/// TransportError::new(
/// TransportErrorKind::TimedOut,
/// "connection to server timed out"
/// )
/// )
/// );
///
///
/// // use error variant constructor
/// let err3: thrift::Result<()> = Err(
/// thrift::new_transport_error(
/// TransportErrorKind::TimedOut,
/// "connection to server timed out"
/// )
/// );
/// ```
///
/// Create an error from a string.
///
/// ```
/// use thrift;
/// use thrift::{ApplicationError, ApplicationErrorKind};
///
/// // we just use `From::from` to convert a `String` into a `thrift::Error`
/// let err0: thrift::Result<()> = Err(
/// thrift::Error::from("This is an error")
/// );
///
/// // err0 is equivalent to...
/// let err1: thrift::Result<()> = Err(
/// thrift::Error::Application(
/// ApplicationError {
/// kind: ApplicationErrorKind::Unknown,
/// message: format!("This is an error")
/// }
/// )
/// );
/// ```
///
/// Return an IDL-defined exception.
///
/// ```text
/// // Thrift IDL exception definition.
/// exception Xception {
/// 1: i32 errorCode,
/// 2: string message
/// }
/// ```
///
/// ```
/// use std::convert::From;
/// use std::error::Error;
/// use std::fmt;
/// use std::fmt::{Display, Formatter};
///
/// // auto-generated by the Thrift compiler
/// #[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
/// pub struct Xception {
/// pub error_code: Option<i32>,
/// pub message: Option<String>,
/// }
///
/// // auto-generated by the Thrift compiler
/// impl Error for Xception {
/// fn description(&self) -> &str {
/// "remote service threw Xception"
/// }
/// }
///
/// // auto-generated by the Thrift compiler
/// impl From<Xception> for thrift::Error {
/// fn from(e: Xception) -> Self {
/// thrift::Error::User(Box::new(e))
/// }
/// }
///
/// // auto-generated by the Thrift compiler
/// impl Display for Xception {
/// fn fmt(&self, f: &mut Formatter) -> fmt::Result {
/// self.description().fmt(f)
/// }
/// }
///
/// // in user code...
/// let err: thrift::Result<()> = Err(
/// thrift::Error::from(Xception { error_code: Some(1), message: None })
/// );
/// ```
pub enum Error {
/// Errors encountered while operating on I/O channels.
///
/// These include *connection closed* and *bind failure*.
Transport(TransportError),
/// Errors encountered during runtime-library processing.
///
/// These include *message too large* and *unsupported protocol version*.
Protocol(ProtocolError),
/// Errors encountered within auto-generated code, or when incoming
/// or outgoing messages violate the Thrift spec.
///
/// These include *out-of-order messages* and *missing required struct
/// fields*.
///
/// This variant also functions as a catch-all: errors from handler
/// functions are automatically returned as an `ApplicationError`.
Application(ApplicationError),
/// IDL-defined exception structs.
User(Box<error::Error + Sync + Send>),
}
impl Error {
/// Create an `ApplicationError` from its wire representation.
///
/// Application code **should never** call this method directly.
pub fn read_application_error_from_in_protocol(i: &mut TInputProtocol,)
-> ::Result<ApplicationError> {
let mut message = "general remote error".to_owned();
let mut kind = ApplicationErrorKind::Unknown;
i.read_struct_begin()?;
loop {
let field_ident = i.read_field_begin()?;
if field_ident.field_type == TType::Stop {
break;
}
let id = field_ident
.id
.expect("sender should always specify id for non-STOP field");
match id {
1 => {
let remote_message = i.read_string()?;
i.read_field_end()?;
message = remote_message;
}
2 => {
let remote_type_as_int = i.read_i32()?;
let remote_kind: ApplicationErrorKind =
TryFrom::try_from(remote_type_as_int)
.unwrap_or(ApplicationErrorKind::Unknown);
i.read_field_end()?;
kind = remote_kind;
}
_ => {
i.skip(field_ident.field_type)?;
}
}
}
i.read_struct_end()?;
Ok(
ApplicationError {
kind: kind,
message: message,
},
)
}
/// Convert an `ApplicationError` into its wire representation and write
/// it to the remote.
///
/// Application code **should never** call this method directly.
pub fn write_application_error_to_out_protocol(
e: &ApplicationError,
o: &mut TOutputProtocol,
) -> ::Result<()> {
o.write_struct_begin(&TStructIdentifier { name: "TApplicationException".to_owned() },)?;
let message_field = TFieldIdentifier::new("message", TType::String, 1);
let type_field = TFieldIdentifier::new("type", TType::I32, 2);
o.write_field_begin(&message_field)?;
o.write_string(&e.message)?;
o.write_field_end()?;
o.write_field_begin(&type_field)?;
o.write_i32(e.kind as i32)?;
o.write_field_end()?;
o.write_field_stop()?;
o.write_struct_end()?;
o.flush()
}
}
impl error::Error for Error {
fn description(&self) -> &str {
match *self {
Error::Transport(ref e) => TransportError::description(e),
Error::Protocol(ref e) => ProtocolError::description(e),
Error::Application(ref e) => ApplicationError::description(e),
Error::User(ref e) => e.description(),
}
}
}
impl Debug for Error {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match *self {
Error::Transport(ref e) => Debug::fmt(e, f),
Error::Protocol(ref e) => Debug::fmt(e, f),
Error::Application(ref e) => Debug::fmt(e, f),
Error::User(ref e) => Debug::fmt(e, f),
}
}
}
impl Display for Error {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match *self {
Error::Transport(ref e) => Display::fmt(e, f),
Error::Protocol(ref e) => Display::fmt(e, f),
Error::Application(ref e) => Display::fmt(e, f),
Error::User(ref e) => Display::fmt(e, f),
}
}
}
impl From<String> for Error {
fn from(s: String) -> Self {
Error::Application(
ApplicationError {
kind: ApplicationErrorKind::Unknown,
message: s,
},
)
}
}
impl<'a> From<&'a str> for Error {
fn from(s: &'a str) -> Self {
Error::Application(
ApplicationError {
kind: ApplicationErrorKind::Unknown,
message: String::from(s),
},
)
}
}
impl From<TransportError> for Error {
fn from(e: TransportError) -> Self {
Error::Transport(e)
}
}
impl From<ProtocolError> for Error {
fn from(e: ProtocolError) -> Self {
Error::Protocol(e)
}
}
impl From<ApplicationError> for Error {
fn from(e: ApplicationError) -> Self {
Error::Application(e)
}
}
/// Create a new `Error` instance of type `Transport` that wraps a
/// `TransportError`.
pub fn new_transport_error<S: Into<String>>(kind: TransportErrorKind, message: S) -> Error {
Error::Transport(TransportError::new(kind, message))
}
/// Information about I/O errors.
#[derive(Debug, Eq, PartialEq)]
pub struct TransportError {
/// I/O error variant.
///
/// If a specific `TransportErrorKind` does not apply use
/// `TransportErrorKind::Unknown`.
pub kind: TransportErrorKind,
/// Human-readable error message.
pub message: String,
}
impl TransportError {
/// Create a new `TransportError`.
pub fn new<S: Into<String>>(kind: TransportErrorKind, message: S) -> TransportError {
TransportError {
kind: kind,
message: message.into(),
}
}
}
/// I/O error categories.
///
/// This list may grow, and it is not recommended to match against it.
#[derive(Clone, Copy, Eq, Debug, PartialEq)]
pub enum TransportErrorKind {
/// Catch-all I/O error.
Unknown = 0,
/// An I/O operation was attempted when the transport channel was not open.
NotOpen = 1,
/// The transport channel cannot be opened because it was opened previously.
AlreadyOpen = 2,
/// An I/O operation timed out.
TimedOut = 3,
/// A read could not complete because no bytes were available.
EndOfFile = 4,
/// An invalid (buffer/message) size was requested or received.
NegativeSize = 5,
/// Too large a buffer or message size was requested or received.
SizeLimit = 6,
}
impl TransportError {
fn description(&self) -> &str {
match self.kind {
TransportErrorKind::Unknown => "transport error",
TransportErrorKind::NotOpen => "not open",
TransportErrorKind::AlreadyOpen => "already open",
TransportErrorKind::TimedOut => "timed out",
TransportErrorKind::EndOfFile => "end of file",
TransportErrorKind::NegativeSize => "negative size message",
TransportErrorKind::SizeLimit => "message too long",
}
}
}
impl Display for TransportError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.description())
}
}
impl TryFrom<i32> for TransportErrorKind {
type Err = Error;
fn try_from(from: i32) -> Result<Self, Self::Err> {
match from {
0 => Ok(TransportErrorKind::Unknown),
1 => Ok(TransportErrorKind::NotOpen),
2 => Ok(TransportErrorKind::AlreadyOpen),
3 => Ok(TransportErrorKind::TimedOut),
4 => Ok(TransportErrorKind::EndOfFile),
5 => Ok(TransportErrorKind::NegativeSize),
6 => Ok(TransportErrorKind::SizeLimit),
_ => {
Err(
Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::Unknown,
message: format!("cannot convert {} to TransportErrorKind", from),
},
),
)
}
}
}
}
impl From<io::Error> for Error {
fn from(err: io::Error) -> Self {
match err.kind() {
io::ErrorKind::ConnectionReset |
io::ErrorKind::ConnectionRefused |
io::ErrorKind::NotConnected => {
Error::Transport(
TransportError {
kind: TransportErrorKind::NotOpen,
message: err.description().to_owned(),
},
)
}
io::ErrorKind::AlreadyExists => {
Error::Transport(
TransportError {
kind: TransportErrorKind::AlreadyOpen,
message: err.description().to_owned(),
},
)
}
io::ErrorKind::TimedOut => {
Error::Transport(
TransportError {
kind: TransportErrorKind::TimedOut,
message: err.description().to_owned(),
},
)
}
io::ErrorKind::UnexpectedEof => {
Error::Transport(
TransportError {
kind: TransportErrorKind::EndOfFile,
message: err.description().to_owned(),
},
)
}
_ => {
Error::Transport(
TransportError {
kind: TransportErrorKind::Unknown,
message: err.description().to_owned(), // FIXME: use io error's debug string
},
)
}
}
}
}
impl From<string::FromUtf8Error> for Error {
fn from(err: string::FromUtf8Error) -> Self {
Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::InvalidData,
message: err.description().to_owned(), // FIXME: use fmt::Error's debug string
},
)
}
}
/// Create a new `Error` instance of type `Protocol` that wraps a
/// `ProtocolError`.
pub fn new_protocol_error<S: Into<String>>(kind: ProtocolErrorKind, message: S) -> Error {
Error::Protocol(ProtocolError::new(kind, message))
}
/// Information about errors that occur in the runtime library.
#[derive(Debug, Eq, PartialEq)]
pub struct ProtocolError {
/// Protocol error variant.
///
/// If a specific `ProtocolErrorKind` does not apply use
/// `ProtocolErrorKind::Unknown`.
pub kind: ProtocolErrorKind,
/// Human-readable error message.
pub message: String,
}
impl ProtocolError {
/// Create a new `ProtocolError`.
pub fn new<S: Into<String>>(kind: ProtocolErrorKind, message: S) -> ProtocolError {
ProtocolError {
kind: kind,
message: message.into(),
}
}
}
/// Runtime library error categories.
///
/// This list may grow, and it is not recommended to match against it.
#[derive(Clone, Copy, Eq, Debug, PartialEq)]
pub enum ProtocolErrorKind {
/// Catch-all runtime-library error.
Unknown = 0,
/// An invalid argument was supplied to a library function, or invalid data
/// was received from a Thrift endpoint.
InvalidData = 1,
/// An invalid size was received in an encoded field.
NegativeSize = 2,
/// Thrift message or field was too long.
SizeLimit = 3,
/// Unsupported or unknown Thrift protocol version.
BadVersion = 4,
/// Unsupported Thrift protocol, server or field type.
NotImplemented = 5,
/// Reached the maximum nested depth to which an encoded Thrift field could
/// be skipped.
DepthLimit = 6,
}
impl ProtocolError {
fn description(&self) -> &str {
match self.kind {
ProtocolErrorKind::Unknown => "protocol error",
ProtocolErrorKind::InvalidData => "bad data",
ProtocolErrorKind::NegativeSize => "negative message size",
ProtocolErrorKind::SizeLimit => "message too long",
ProtocolErrorKind::BadVersion => "invalid thrift version",
ProtocolErrorKind::NotImplemented => "not implemented",
ProtocolErrorKind::DepthLimit => "maximum skip depth reached",
}
}
}
impl Display for ProtocolError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.description())
}
}
impl TryFrom<i32> for ProtocolErrorKind {
type Err = Error;
fn try_from(from: i32) -> Result<Self, Self::Err> {
match from {
0 => Ok(ProtocolErrorKind::Unknown),
1 => Ok(ProtocolErrorKind::InvalidData),
2 => Ok(ProtocolErrorKind::NegativeSize),
3 => Ok(ProtocolErrorKind::SizeLimit),
4 => Ok(ProtocolErrorKind::BadVersion),
5 => Ok(ProtocolErrorKind::NotImplemented),
6 => Ok(ProtocolErrorKind::DepthLimit),
_ => {
Err(
Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::Unknown,
message: format!("cannot convert {} to ProtocolErrorKind", from),
},
),
)
}
}
}
}
/// Create a new `Error` instance of type `Application` that wraps an
/// `ApplicationError`.
pub fn new_application_error<S: Into<String>>(kind: ApplicationErrorKind, message: S) -> Error {
Error::Application(ApplicationError::new(kind, message))
}
/// Information about errors in auto-generated code or in user-implemented
/// service handlers.
#[derive(Debug, Eq, PartialEq)]
pub struct ApplicationError {
/// Application error variant.
///
/// If a specific `ApplicationErrorKind` does not apply use
/// `ApplicationErrorKind::Unknown`.
pub kind: ApplicationErrorKind,
/// Human-readable error message.
pub message: String,
}
impl ApplicationError {
/// Create a new `ApplicationError`.
pub fn new<S: Into<String>>(kind: ApplicationErrorKind, message: S) -> ApplicationError {
ApplicationError {
kind: kind,
message: message.into(),
}
}
}
/// Auto-generated or user-implemented code error categories.
///
/// This list may grow, and it is not recommended to match against it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ApplicationErrorKind {
/// Catch-all application error.
Unknown = 0,
/// Made service call to an unknown service method.
UnknownMethod = 1,
/// Received an unknown Thrift message type. That is, not one of the
/// `thrift::protocol::TMessageType` variants.
InvalidMessageType = 2,
/// Method name in a service reply does not match the name of the
/// receiving service method.
WrongMethodName = 3,
/// Received an out-of-order Thrift message.
BadSequenceId = 4,
/// Service reply is missing required fields.
MissingResult = 5,
/// Auto-generated code failed unexpectedly.
InternalError = 6,
/// Thrift protocol error. When possible use `Error::ProtocolError` with a
/// specific `ProtocolErrorKind` instead.
ProtocolError = 7,
/// *Unknown*. Included only for compatibility with existing Thrift implementations.
InvalidTransform = 8, // ??
/// Thrift endpoint requested, or is using, an unsupported encoding.
InvalidProtocol = 9, // ??
/// Thrift endpoint requested, or is using, an unsupported auto-generated client type.
UnsupportedClientType = 10, // ??
}
impl ApplicationError {
fn description(&self) -> &str {
match self.kind {
ApplicationErrorKind::Unknown => "service error",
ApplicationErrorKind::UnknownMethod => "unknown service method",
ApplicationErrorKind::InvalidMessageType => "wrong message type received",
ApplicationErrorKind::WrongMethodName => "unknown method reply received",
ApplicationErrorKind::BadSequenceId => "out of order sequence id",
ApplicationErrorKind::MissingResult => "missing method result",
ApplicationErrorKind::InternalError => "remote service threw exception",
ApplicationErrorKind::ProtocolError => "protocol error",
ApplicationErrorKind::InvalidTransform => "invalid transform",
ApplicationErrorKind::InvalidProtocol => "invalid protocol requested",
ApplicationErrorKind::UnsupportedClientType => "unsupported protocol client",
}
}
}
impl Display for ApplicationError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.description())
}
}
impl TryFrom<i32> for ApplicationErrorKind {
type Err = Error;
fn try_from(from: i32) -> Result<Self, Self::Err> {
match from {
0 => Ok(ApplicationErrorKind::Unknown),
1 => Ok(ApplicationErrorKind::UnknownMethod),
2 => Ok(ApplicationErrorKind::InvalidMessageType),
3 => Ok(ApplicationErrorKind::WrongMethodName),
4 => Ok(ApplicationErrorKind::BadSequenceId),
5 => Ok(ApplicationErrorKind::MissingResult),
6 => Ok(ApplicationErrorKind::InternalError),
7 => Ok(ApplicationErrorKind::ProtocolError),
8 => Ok(ApplicationErrorKind::InvalidTransform),
9 => Ok(ApplicationErrorKind::InvalidProtocol),
10 => Ok(ApplicationErrorKind::UnsupportedClientType),
_ => {
Err(
Error::Application(
ApplicationError {
kind: ApplicationErrorKind::Unknown,
message: format!("cannot convert {} to ApplicationErrorKind", from),
},
),
)
}
}
}
}

89
vendor/git.apache.org/thrift.git/lib/rs/src/lib.rs generated vendored Normal file
View file

@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Rust runtime library for the Apache Thrift RPC system.
//!
//! This crate implements the components required to build a working
//! Thrift server and client. It is divided into the following modules:
//!
//! 1. errors
//! 2. protocol
//! 3. transport
//! 4. server
//! 5. autogen
//!
//! The modules are layered as shown in the diagram below. The `autogen'd`
//! layer is generated by the Thrift compiler's Rust plugin. It uses the
//! types and functions defined in this crate to serialize and deserialize
//! messages and implement RPC. Users interact with these types and services
//! by writing their own code that uses the auto-generated clients and
//! servers.
//!
//! ```text
//! +-----------+
//! | user app |
//! +-----------+
//! | autogen'd | (uses errors, autogen)
//! +-----------+
//! | protocol |
//! +-----------+
//! | transport |
//! +-----------+
//! ```
#![crate_type = "lib"]
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
extern crate byteorder;
extern crate integer_encoding;
extern crate threadpool;
extern crate try_from;
#[macro_use]
extern crate log;
// NOTE: this macro has to be defined before any modules. See:
// https://danielkeep.github.io/quick-intro-to-macros.html#some-more-gotchas
/// Assert that an expression returning a `Result` is a success. If it is,
/// return the value contained in the result, i.e. `expr.unwrap()`.
#[cfg(test)]
macro_rules! assert_success {
($e: expr) => {
{
let res = $e;
assert!(res.is_ok());
res.unwrap()
}
}
}
pub mod protocol;
pub mod server;
pub mod transport;
mod errors;
pub use errors::*;
mod autogen;
pub use autogen::*;
/// Result type returned by all runtime library functions.
///
/// As is convention this is a typedef of `std::result::Result`
/// with `E` defined as the `thrift::Error` type.
pub type Result<T> = std::result::Result<T, self::Error>;

View file

@ -0,0 +1,919 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
use std::convert::From;
use try_from::TryFrom;
use {ProtocolError, ProtocolErrorKind};
use transport::{TReadTransport, TWriteTransport};
use super::{TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier,
TMapIdentifier, TMessageIdentifier, TMessageType};
use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
const BINARY_PROTOCOL_VERSION_1: u32 = 0x80010000;
/// Read messages encoded in the Thrift simple binary encoding.
///
/// There are two available modes: `strict` and `non-strict`, where the
/// `non-strict` version does not check for the protocol version in the
/// received message header.
///
/// # Examples
///
/// Create and use a `TBinaryInputProtocol`.
///
/// ```no_run
/// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol};
/// use thrift::transport::TTcpChannel;
///
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let mut protocol = TBinaryInputProtocol::new(channel, true);
///
/// let recvd_bool = protocol.read_bool().unwrap();
/// let recvd_string = protocol.read_string().unwrap();
/// ```
#[derive(Debug)]
pub struct TBinaryInputProtocol<T>
where
T: TReadTransport,
{
strict: bool,
pub transport: T, // FIXME: shouldn't be public
}
impl<'a, T> TBinaryInputProtocol<T>
where
T: TReadTransport,
{
/// Create a `TBinaryInputProtocol` that reads bytes from `transport`.
///
/// Set `strict` to `true` if all incoming messages contain the protocol
/// version number in the protocol header.
pub fn new(transport: T, strict: bool) -> TBinaryInputProtocol<T> {
TBinaryInputProtocol {
strict: strict,
transport: transport,
}
}
}
impl<T> TInputProtocol for TBinaryInputProtocol<T>
where
T: TReadTransport,
{
#[cfg_attr(feature = "cargo-clippy", allow(collapsible_if))]
fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> {
let mut first_bytes = vec![0; 4];
self.transport.read_exact(&mut first_bytes[..])?;
// the thrift version header is intentionally negative
// so the first check we'll do is see if the sign bit is set
// and if so - assume it's the protocol-version header
if first_bytes[0] >= 8 {
// apparently we got a protocol-version header - check
// it, and if it matches, read the rest of the fields
if first_bytes[0..2] != [0x80, 0x01] {
Err(
::Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::BadVersion,
message: format!("received bad version: {:?}", &first_bytes[0..2]),
},
),
)
} else {
let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?;
let name = self.read_string()?;
let sequence_number = self.read_i32()?;
Ok(TMessageIdentifier::new(name, message_type, sequence_number))
}
} else {
// apparently we didn't get a protocol-version header,
// which happens if the sender is not using the strict protocol
if self.strict {
// we're in strict mode however, and that always
// requires the protocol-version header to be written first
Err(
::Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::BadVersion,
message: format!("received bad version: {:?}", &first_bytes[0..2]),
},
),
)
} else {
// in the non-strict version the first message field
// is the message name. strings (byte arrays) are length-prefixed,
// so we've just read the length in the first 4 bytes
let name_size = BigEndian::read_i32(&first_bytes) as usize;
let mut name_buf: Vec<u8> = Vec::with_capacity(name_size);
self.transport.read_exact(&mut name_buf)?;
let name = String::from_utf8(name_buf)?;
// read the rest of the fields
let message_type: TMessageType = self.read_byte().and_then(TryFrom::try_from)?;
let sequence_number = self.read_i32()?;
Ok(TMessageIdentifier::new(name, message_type, sequence_number))
}
}
}
fn read_message_end(&mut self) -> ::Result<()> {
Ok(())
}
fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> {
Ok(None)
}
fn read_struct_end(&mut self) -> ::Result<()> {
Ok(())
}
fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> {
let field_type_byte = self.read_byte()?;
let field_type = field_type_from_u8(field_type_byte)?;
let id = match field_type {
TType::Stop => Ok(0),
_ => self.read_i16(),
}?;
Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id),)
}
fn read_field_end(&mut self) -> ::Result<()> {
Ok(())
}
fn read_bytes(&mut self) -> ::Result<Vec<u8>> {
let num_bytes = self.transport.read_i32::<BigEndian>()? as usize;
let mut buf = vec![0u8; num_bytes];
self.transport
.read_exact(&mut buf)
.map(|_| buf)
.map_err(From::from)
}
fn read_bool(&mut self) -> ::Result<bool> {
let b = self.read_i8()?;
match b {
0 => Ok(false),
_ => Ok(true),
}
}
fn read_i8(&mut self) -> ::Result<i8> {
self.transport.read_i8().map_err(From::from)
}
fn read_i16(&mut self) -> ::Result<i16> {
self.transport
.read_i16::<BigEndian>()
.map_err(From::from)
}
fn read_i32(&mut self) -> ::Result<i32> {
self.transport
.read_i32::<BigEndian>()
.map_err(From::from)
}
fn read_i64(&mut self) -> ::Result<i64> {
self.transport
.read_i64::<BigEndian>()
.map_err(From::from)
}
fn read_double(&mut self) -> ::Result<f64> {
self.transport
.read_f64::<BigEndian>()
.map_err(From::from)
}
fn read_string(&mut self) -> ::Result<String> {
let bytes = self.read_bytes()?;
String::from_utf8(bytes).map_err(From::from)
}
fn read_list_begin(&mut self) -> ::Result<TListIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
Ok(TListIdentifier::new(element_type, size))
}
fn read_list_end(&mut self) -> ::Result<()> {
Ok(())
}
fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> {
let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
Ok(TSetIdentifier::new(element_type, size))
}
fn read_set_end(&mut self) -> ::Result<()> {
Ok(())
}
fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> {
let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
let size = self.read_i32()?;
Ok(TMapIdentifier::new(key_type, value_type, size))
}
fn read_map_end(&mut self) -> ::Result<()> {
Ok(())
}
// utility
//
fn read_byte(&mut self) -> ::Result<u8> {
self.transport.read_u8().map_err(From::from)
}
}
/// Factory for creating instances of `TBinaryInputProtocol`.
#[derive(Default)]
pub struct TBinaryInputProtocolFactory;
impl TBinaryInputProtocolFactory {
/// Create a `TBinaryInputProtocolFactory`.
pub fn new() -> TBinaryInputProtocolFactory {
TBinaryInputProtocolFactory {}
}
}
impl TInputProtocolFactory for TBinaryInputProtocolFactory {
fn create(&self, transport: Box<TReadTransport + Send>) -> Box<TInputProtocol + Send> {
Box::new(TBinaryInputProtocol::new(transport, true))
}
}
/// Write messages using the Thrift simple binary encoding.
///
/// There are two available modes: `strict` and `non-strict`, where the
/// `strict` version writes the protocol version number in the outgoing message
/// header and the `non-strict` version does not.
///
/// # Examples
///
/// Create and use a `TBinaryOutputProtocol`.
///
/// ```no_run
/// use thrift::protocol::{TBinaryOutputProtocol, TOutputProtocol};
/// use thrift::transport::TTcpChannel;
///
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let mut protocol = TBinaryOutputProtocol::new(channel, true);
///
/// protocol.write_bool(true).unwrap();
/// protocol.write_string("test_string").unwrap();
/// ```
#[derive(Debug)]
pub struct TBinaryOutputProtocol<T>
where
T: TWriteTransport,
{
strict: bool,
pub transport: T, // FIXME: do not make public; only public for testing!
}
impl<T> TBinaryOutputProtocol<T>
where
T: TWriteTransport,
{
/// Create a `TBinaryOutputProtocol` that writes bytes to `transport`.
///
/// Set `strict` to `true` if all outgoing messages should contain the
/// protocol version number in the protocol header.
pub fn new(transport: T, strict: bool) -> TBinaryOutputProtocol<T> {
TBinaryOutputProtocol {
strict: strict,
transport: transport,
}
}
}
impl<T> TOutputProtocol for TBinaryOutputProtocol<T>
where
T: TWriteTransport,
{
fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> {
if self.strict {
let message_type: u8 = identifier.message_type.into();
let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32);
self.transport.write_u32::<BigEndian>(header)?;
self.write_string(&identifier.name)?;
self.write_i32(identifier.sequence_number)
} else {
self.write_string(&identifier.name)?;
self.write_byte(identifier.message_type.into())?;
self.write_i32(identifier.sequence_number)
}
}
fn write_message_end(&mut self) -> ::Result<()> {
Ok(())
}
fn write_struct_begin(&mut self, _: &TStructIdentifier) -> ::Result<()> {
Ok(())
}
fn write_struct_end(&mut self) -> ::Result<()> {
Ok(())
}
fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> {
if identifier.id.is_none() && identifier.field_type != TType::Stop {
return Err(
::Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::Unknown,
message: format!(
"cannot write identifier {:?} without sequence number",
&identifier
),
},
),
);
}
self.write_byte(field_type_to_u8(identifier.field_type))?;
if let Some(id) = identifier.id {
self.write_i16(id)
} else {
Ok(())
}
}
fn write_field_end(&mut self) -> ::Result<()> {
Ok(())
}
fn write_field_stop(&mut self) -> ::Result<()> {
self.write_byte(field_type_to_u8(TType::Stop))
}
fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> {
self.write_i32(b.len() as i32)?;
self.transport.write_all(b).map_err(From::from)
}
fn write_bool(&mut self, b: bool) -> ::Result<()> {
if b {
self.write_i8(1)
} else {
self.write_i8(0)
}
}
fn write_i8(&mut self, i: i8) -> ::Result<()> {
self.transport.write_i8(i).map_err(From::from)
}
fn write_i16(&mut self, i: i16) -> ::Result<()> {
self.transport
.write_i16::<BigEndian>(i)
.map_err(From::from)
}
fn write_i32(&mut self, i: i32) -> ::Result<()> {
self.transport
.write_i32::<BigEndian>(i)
.map_err(From::from)
}
fn write_i64(&mut self, i: i64) -> ::Result<()> {
self.transport
.write_i64::<BigEndian>(i)
.map_err(From::from)
}
fn write_double(&mut self, d: f64) -> ::Result<()> {
self.transport
.write_f64::<BigEndian>(d)
.map_err(From::from)
}
fn write_string(&mut self, s: &str) -> ::Result<()> {
self.write_bytes(s.as_bytes())
}
fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> {
self.write_byte(field_type_to_u8(identifier.element_type))?;
self.write_i32(identifier.size)
}
fn write_list_end(&mut self) -> ::Result<()> {
Ok(())
}
fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> {
self.write_byte(field_type_to_u8(identifier.element_type))?;
self.write_i32(identifier.size)
}
fn write_set_end(&mut self) -> ::Result<()> {
Ok(())
}
fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> {
let key_type = identifier
.key_type
.expect("map identifier to write should contain key type");
self.write_byte(field_type_to_u8(key_type))?;
let val_type = identifier
.value_type
.expect("map identifier to write should contain value type");
self.write_byte(field_type_to_u8(val_type))?;
self.write_i32(identifier.size)
}
fn write_map_end(&mut self) -> ::Result<()> {
Ok(())
}
fn flush(&mut self) -> ::Result<()> {
self.transport.flush().map_err(From::from)
}
// utility
//
fn write_byte(&mut self, b: u8) -> ::Result<()> {
self.transport.write_u8(b).map_err(From::from)
}
}
/// Factory for creating instances of `TBinaryOutputProtocol`.
#[derive(Default)]
pub struct TBinaryOutputProtocolFactory;
impl TBinaryOutputProtocolFactory {
/// Create a `TBinaryOutputProtocolFactory`.
pub fn new() -> TBinaryOutputProtocolFactory {
TBinaryOutputProtocolFactory {}
}
}
impl TOutputProtocolFactory for TBinaryOutputProtocolFactory {
fn create(&self, transport: Box<TWriteTransport + Send>) -> Box<TOutputProtocol + Send> {
Box::new(TBinaryOutputProtocol::new(transport, true))
}
}
fn field_type_to_u8(field_type: TType) -> u8 {
match field_type {
TType::Stop => 0x00,
TType::Void => 0x01,
TType::Bool => 0x02,
TType::I08 => 0x03, // equivalent to TType::Byte
TType::Double => 0x04,
TType::I16 => 0x06,
TType::I32 => 0x08,
TType::I64 => 0x0A,
TType::String | TType::Utf7 => 0x0B,
TType::Struct => 0x0C,
TType::Map => 0x0D,
TType::Set => 0x0E,
TType::List => 0x0F,
TType::Utf8 => 0x10,
TType::Utf16 => 0x11,
}
}
fn field_type_from_u8(b: u8) -> ::Result<TType> {
match b {
0x00 => Ok(TType::Stop),
0x01 => Ok(TType::Void),
0x02 => Ok(TType::Bool),
0x03 => Ok(TType::I08), // Equivalent to TType::Byte
0x04 => Ok(TType::Double),
0x06 => Ok(TType::I16),
0x08 => Ok(TType::I32),
0x0A => Ok(TType::I64),
0x0B => Ok(TType::String), // technically, also a UTF7, but we'll treat it as string
0x0C => Ok(TType::Struct),
0x0D => Ok(TType::Map),
0x0E => Ok(TType::Set),
0x0F => Ok(TType::List),
0x10 => Ok(TType::Utf8),
0x11 => Ok(TType::Utf16),
unkn => {
Err(
::Error::Protocol(
ProtocolError {
kind: ProtocolErrorKind::InvalidData,
message: format!("cannot convert {} to TType", unkn),
},
),
)
}
}
}
#[cfg(test)]
mod tests {
use protocol::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier,
TMessageIdentifier, TMessageType, TOutputProtocol, TSetIdentifier,
TStructIdentifier, TType};
use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
use super::*;
#[test]
fn must_write_message_call_begin() {
let (_, mut o_prot) = test_objects();
let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
assert!(o_prot.write_message_begin(&ident).is_ok());
let expected: [u8; 16] = [
0x80,
0x01,
0x00,
0x01,
0x00,
0x00,
0x00,
0x04,
0x74,
0x65,
0x73,
0x74,
0x00,
0x00,
0x00,
0x01,
];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_write_message_reply_begin() {
let (_, mut o_prot) = test_objects();
let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
assert!(o_prot.write_message_begin(&ident).is_ok());
let expected: [u8; 16] = [
0x80,
0x01,
0x00,
0x02,
0x00,
0x00,
0x00,
0x04,
0x74,
0x65,
0x73,
0x74,
0x00,
0x00,
0x00,
0x0A,
];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_round_trip_strict_message_begin() {
let (mut i_prot, mut o_prot) = test_objects();
let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
assert!(o_prot.write_message_begin(&sent_ident).is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_message_begin());
assert_eq!(&received_ident, &sent_ident);
}
#[test]
fn must_write_message_end() {
assert_no_write(|o| o.write_message_end());
}
#[test]
fn must_write_struct_begin() {
assert_no_write(|o| o.write_struct_begin(&TStructIdentifier::new("foo")));
}
#[test]
fn must_write_struct_end() {
assert_no_write(|o| o.write_struct_end());
}
#[test]
fn must_write_field_begin() {
let (_, mut o_prot) = test_objects();
assert!(
o_prot
.write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22))
.is_ok()
);
let expected: [u8; 3] = [0x0B, 0x00, 0x16];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_round_trip_field_begin() {
let (mut i_prot, mut o_prot) = test_objects();
let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20);
assert!(o_prot.write_field_begin(&sent_field_ident).is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let expected_ident = TFieldIdentifier {
name: None,
field_type: TType::I64,
id: Some(20),
}; // no name
let received_ident = assert_success!(i_prot.read_field_begin());
assert_eq!(&received_ident, &expected_ident);
}
#[test]
fn must_write_stop_field() {
let (_, mut o_prot) = test_objects();
assert!(o_prot.write_field_stop().is_ok());
let expected: [u8; 1] = [0x00];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_round_trip_field_stop() {
let (mut i_prot, mut o_prot) = test_objects();
assert!(o_prot.write_field_stop().is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let expected_ident = TFieldIdentifier {
name: None,
field_type: TType::Stop,
id: Some(0),
}; // we get id 0
let received_ident = assert_success!(i_prot.read_field_begin());
assert_eq!(&received_ident, &expected_ident);
}
#[test]
fn must_write_field_end() {
assert_no_write(|o| o.write_field_end());
}
#[test]
fn must_write_list_begin() {
let (_, mut o_prot) = test_objects();
assert!(
o_prot
.write_list_begin(&TListIdentifier::new(TType::Bool, 5))
.is_ok()
);
let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_round_trip_list_begin() {
let (mut i_prot, mut o_prot) = test_objects();
let ident = TListIdentifier::new(TType::List, 900);
assert!(o_prot.write_list_begin(&ident).is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_list_begin());
assert_eq!(&received_ident, &ident);
}
#[test]
fn must_write_list_end() {
assert_no_write(|o| o.write_list_end());
}
#[test]
fn must_write_set_begin() {
let (_, mut o_prot) = test_objects();
assert!(
o_prot
.write_set_begin(&TSetIdentifier::new(TType::I16, 7))
.is_ok()
);
let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_round_trip_set_begin() {
let (mut i_prot, mut o_prot) = test_objects();
let ident = TSetIdentifier::new(TType::I64, 2000);
assert!(o_prot.write_set_begin(&ident).is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident_result = i_prot.read_set_begin();
assert!(received_ident_result.is_ok());
assert_eq!(&received_ident_result.unwrap(), &ident);
}
#[test]
fn must_write_set_end() {
assert_no_write(|o| o.write_set_end());
}
#[test]
fn must_write_map_begin() {
let (_, mut o_prot) = test_objects();
assert!(
o_prot
.write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32))
.is_ok()
);
let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_round_trip_map_begin() {
let (mut i_prot, mut o_prot) = test_objects();
let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
assert!(o_prot.write_map_begin(&ident).is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_ident = assert_success!(i_prot.read_map_begin());
assert_eq!(&received_ident, &ident);
}
#[test]
fn must_write_map_end() {
assert_no_write(|o| o.write_map_end());
}
#[test]
fn must_write_bool_true() {
let (_, mut o_prot) = test_objects();
assert!(o_prot.write_bool(true).is_ok());
let expected: [u8; 1] = [0x01];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_write_bool_false() {
let (_, mut o_prot) = test_objects();
assert!(o_prot.write_bool(false).is_ok());
let expected: [u8; 1] = [0x00];
assert_eq_written_bytes!(o_prot, expected);
}
#[test]
fn must_read_bool_true() {
let (mut i_prot, _) = test_objects();
set_readable_bytes!(i_prot, &[0x01]);
let read_bool = assert_success!(i_prot.read_bool());
assert_eq!(read_bool, true);
}
#[test]
fn must_read_bool_false() {
let (mut i_prot, _) = test_objects();
set_readable_bytes!(i_prot, &[0x00]);
let read_bool = assert_success!(i_prot.read_bool());
assert_eq!(read_bool, false);
}
#[test]
fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() {
let (mut i_prot, _) = test_objects();
set_readable_bytes!(i_prot, &[0xAC]);
let read_bool = assert_success!(i_prot.read_bool());
assert_eq!(read_bool, true);
}
#[test]
fn must_write_bytes() {
let (_, mut o_prot) = test_objects();
let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF];
assert!(o_prot.write_bytes(&bytes).is_ok());
let buf = o_prot.transport.write_bytes();
assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]); // length
assert_eq!(&buf[4..], bytes); // actual bytes
}
#[test]
fn must_round_trip_bytes() {
let (mut i_prot, mut o_prot) = test_objects();
let bytes: [u8; 25] = [
0x20,
0xFD,
0x18,
0x84,
0x99,
0x12,
0xAB,
0xBB,
0x45,
0xDF,
0x34,
0xDC,
0x98,
0xA4,
0x6D,
0xF3,
0x99,
0xB4,
0xB7,
0xD4,
0x9C,
0xA5,
0xB3,
0xC9,
0x88,
];
assert!(o_prot.write_bytes(&bytes).is_ok());
copy_write_buffer_to_read_buffer!(o_prot);
let received_bytes = assert_success!(i_prot.read_bytes());
assert_eq!(&received_bytes, &bytes);
}
fn test_objects()
-> (TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
TBinaryOutputProtocol<WriteHalf<TBufferChannel>>)
{
let mem = TBufferChannel::with_capacity(40, 40);
let (r_mem, w_mem) = mem.split().unwrap();
let i_prot = TBinaryInputProtocol::new(r_mem, true);
let o_prot = TBinaryOutputProtocol::new(w_mem, true);
(i_prot, o_prot)
}
fn assert_no_write<F>(mut write_fn: F)
where
F: FnMut(&mut TBinaryOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>,
{
let (_, mut o_prot) = test_objects();
assert!(write_fn(&mut o_prot).is_ok());
assert_eq!(o_prot.transport.write_bytes().len(), 0);
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,237 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType,
TOutputProtocol, TSetIdentifier, TStructIdentifier};
/// `TOutputProtocol` that prefixes the service name to all outgoing Thrift
/// messages.
///
/// A `TMultiplexedOutputProtocol` should be used when multiple Thrift services
/// send messages over a single I/O channel. By prefixing service identifiers
/// to outgoing messages receivers are able to demux them and route them to the
/// appropriate service processor. Rust receivers must use a `TMultiplexedProcessor`
/// to process incoming messages, while other languages must use their
/// corresponding multiplexed processor implementations.
///
/// For example, given a service `TestService` and a service call `test_call`,
/// this implementation would identify messages as originating from
/// `TestService:test_call`.
///
/// # Examples
///
/// Create and use a `TMultiplexedOutputProtocol`.
///
/// ```no_run
/// use thrift::protocol::{TMessageIdentifier, TMessageType, TOutputProtocol};
/// use thrift::protocol::{TBinaryOutputProtocol, TMultiplexedOutputProtocol};
/// use thrift::transport::TTcpChannel;
///
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let protocol = TBinaryOutputProtocol::new(channel, true);
/// let mut protocol = TMultiplexedOutputProtocol::new("service_name", protocol);
///
/// let ident = TMessageIdentifier::new("svc_call", TMessageType::Call, 1);
/// protocol.write_message_begin(&ident).unwrap();
/// ```
#[derive(Debug)]
pub struct TMultiplexedOutputProtocol<P>
where
P: TOutputProtocol,
{
service_name: String,
inner: P,
}
impl<P> TMultiplexedOutputProtocol<P>
where
P: TOutputProtocol,
{
/// Create a `TMultiplexedOutputProtocol` that identifies outgoing messages
/// as originating from a service named `service_name` and sends them over
/// the `wrapped` `TOutputProtocol`. Outgoing messages are encoded and sent
/// by `wrapped`, not by this instance.
pub fn new(service_name: &str, wrapped: P) -> TMultiplexedOutputProtocol<P> {
TMultiplexedOutputProtocol {
service_name: service_name.to_owned(),
inner: wrapped,
}
}
}
// FIXME: avoid passthrough methods
impl<P> TOutputProtocol for TMultiplexedOutputProtocol<P>
where
P: TOutputProtocol,
{
fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> {
match identifier.message_type { // FIXME: is there a better way to override identifier here?
TMessageType::Call | TMessageType::OneWay => {
let identifier = TMessageIdentifier {
name: format!("{}:{}", self.service_name, identifier.name),
..*identifier
};
self.inner.write_message_begin(&identifier)
}
_ => self.inner.write_message_begin(identifier),
}
}
fn write_message_end(&mut self) -> ::Result<()> {
self.inner.write_message_end()
}
fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()> {
self.inner.write_struct_begin(identifier)
}
fn write_struct_end(&mut self) -> ::Result<()> {
self.inner.write_struct_end()
}
fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> {
self.inner.write_field_begin(identifier)
}
fn write_field_end(&mut self) -> ::Result<()> {
self.inner.write_field_end()
}
fn write_field_stop(&mut self) -> ::Result<()> {
self.inner.write_field_stop()
}
fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> {
self.inner.write_bytes(b)
}
fn write_bool(&mut self, b: bool) -> ::Result<()> {
self.inner.write_bool(b)
}
fn write_i8(&mut self, i: i8) -> ::Result<()> {
self.inner.write_i8(i)
}
fn write_i16(&mut self, i: i16) -> ::Result<()> {
self.inner.write_i16(i)
}
fn write_i32(&mut self, i: i32) -> ::Result<()> {
self.inner.write_i32(i)
}
fn write_i64(&mut self, i: i64) -> ::Result<()> {
self.inner.write_i64(i)
}
fn write_double(&mut self, d: f64) -> ::Result<()> {
self.inner.write_double(d)
}
fn write_string(&mut self, s: &str) -> ::Result<()> {
self.inner.write_string(s)
}
fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> {
self.inner.write_list_begin(identifier)
}
fn write_list_end(&mut self) -> ::Result<()> {
self.inner.write_list_end()
}
fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> {
self.inner.write_set_begin(identifier)
}
fn write_set_end(&mut self) -> ::Result<()> {
self.inner.write_set_end()
}
fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> {
self.inner.write_map_begin(identifier)
}
fn write_map_end(&mut self) -> ::Result<()> {
self.inner.write_map_end()
}
fn flush(&mut self) -> ::Result<()> {
self.inner.flush()
}
// utility
//
fn write_byte(&mut self, b: u8) -> ::Result<()> {
self.inner.write_byte(b)
}
}
#[cfg(test)]
mod tests {
use protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol};
use transport::{TBufferChannel, TIoChannel, WriteHalf};
use super::*;
#[test]
fn must_write_message_begin_with_prefixed_service_name() {
let mut o_prot = test_objects();
let ident = TMessageIdentifier::new("bar", TMessageType::Call, 2);
assert_success!(o_prot.write_message_begin(&ident));
let expected: [u8; 19] = [
0x80,
0x01, /* protocol identifier */
0x00,
0x01, /* message type */
0x00,
0x00,
0x00,
0x07,
0x66,
0x6F,
0x6F, /* "foo" */
0x3A, /* ":" */
0x62,
0x61,
0x72, /* "bar" */
0x00,
0x00,
0x00,
0x02 /* sequence number */,
];
assert_eq!(o_prot.inner.transport.write_bytes(), expected);
}
fn test_objects
()
-> TMultiplexedOutputProtocol<TBinaryOutputProtocol<WriteHalf<TBufferChannel>>>
{
let c = TBufferChannel::with_capacity(40, 40);
let (_, w_chan) = c.split().unwrap();
let prot = TBinaryOutputProtocol::new(w_chan, true);
TMultiplexedOutputProtocol::new("foo", prot)
}
}

View file

@ -0,0 +1,198 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::convert::Into;
use ProtocolErrorKind;
use super::{TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier, TMessageIdentifier,
TSetIdentifier, TStructIdentifier};
/// `TInputProtocol` required to use a `TMultiplexedProcessor`.
///
/// A `TMultiplexedProcessor` reads incoming message identifiers to determine to
/// which `TProcessor` requests should be forwarded. However, once read, those
/// message identifier bytes are no longer on the wire. Since downstream
/// processors expect to read message identifiers from the given input protocol
/// we need some way of supplying a `TMessageIdentifier` with the service-name
/// stripped. This implementation stores the received `TMessageIdentifier`
/// (without the service name) and passes it to the wrapped `TInputProtocol`
/// when `TInputProtocol::read_message_begin(...)` is called. It delegates all
/// other calls directly to the wrapped `TInputProtocol`.
///
/// This type **should not** be used by application code.
///
/// # Examples
///
/// Create and use a `TStoredInputProtocol`.
///
/// ```no_run
/// use thrift;
/// use thrift::protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol};
/// use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TStoredInputProtocol};
/// use thrift::server::TProcessor;
/// use thrift::transport::{TIoChannel, TTcpChannel};
///
/// // sample processor
/// struct ActualProcessor;
/// impl TProcessor for ActualProcessor {
/// fn process(
/// &self,
/// _: &mut TInputProtocol,
/// _: &mut TOutputProtocol
/// ) -> thrift::Result<()> {
/// unimplemented!()
/// }
/// }
/// let processor = ActualProcessor {};
///
/// // construct the shared transport
/// let mut channel = TTcpChannel::new();
/// channel.open("localhost:9090").unwrap();
///
/// let (i_chan, o_chan) = channel.split().unwrap();
///
/// // construct the actual input and output protocols
/// let mut i_prot = TBinaryInputProtocol::new(i_chan, true);
/// let mut o_prot = TBinaryOutputProtocol::new(o_chan, true);
///
/// // message identifier received from remote and modified to remove the service name
/// let new_msg_ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1);
///
/// // construct the proxy input protocol
/// let mut proxy_i_prot = TStoredInputProtocol::new(&mut i_prot, new_msg_ident);
/// let res = processor.process(&mut proxy_i_prot, &mut o_prot);
/// ```
// FIXME: implement Debug
pub struct TStoredInputProtocol<'a> {
inner: &'a mut TInputProtocol,
message_ident: Option<TMessageIdentifier>,
}
impl<'a> TStoredInputProtocol<'a> {
/// Create a `TStoredInputProtocol` that delegates all calls other than
/// `TInputProtocol::read_message_begin(...)` to a `wrapped`
/// `TInputProtocol`. `message_ident` is the modified message identifier -
/// with service name stripped - that will be passed to
/// `wrapped.read_message_begin(...)`.
pub fn new(
wrapped: &mut TInputProtocol,
message_ident: TMessageIdentifier,
) -> TStoredInputProtocol {
TStoredInputProtocol {
inner: wrapped,
message_ident: message_ident.into(),
}
}
}
impl<'a> TInputProtocol for TStoredInputProtocol<'a> {
fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> {
self.message_ident
.take()
.ok_or_else(
|| {
::errors::new_protocol_error(
ProtocolErrorKind::Unknown,
"message identifier already read",
)
},
)
}
fn read_message_end(&mut self) -> ::Result<()> {
self.inner.read_message_end()
}
fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> {
self.inner.read_struct_begin()
}
fn read_struct_end(&mut self) -> ::Result<()> {
self.inner.read_struct_end()
}
fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> {
self.inner.read_field_begin()
}
fn read_field_end(&mut self) -> ::Result<()> {
self.inner.read_field_end()
}
fn read_bytes(&mut self) -> ::Result<Vec<u8>> {
self.inner.read_bytes()
}
fn read_bool(&mut self) -> ::Result<bool> {
self.inner.read_bool()
}
fn read_i8(&mut self) -> ::Result<i8> {
self.inner.read_i8()
}
fn read_i16(&mut self) -> ::Result<i16> {
self.inner.read_i16()
}
fn read_i32(&mut self) -> ::Result<i32> {
self.inner.read_i32()
}
fn read_i64(&mut self) -> ::Result<i64> {
self.inner.read_i64()
}
fn read_double(&mut self) -> ::Result<f64> {
self.inner.read_double()
}
fn read_string(&mut self) -> ::Result<String> {
self.inner.read_string()
}
fn read_list_begin(&mut self) -> ::Result<TListIdentifier> {
self.inner.read_list_begin()
}
fn read_list_end(&mut self) -> ::Result<()> {
self.inner.read_list_end()
}
fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> {
self.inner.read_set_begin()
}
fn read_set_end(&mut self) -> ::Result<()> {
self.inner.read_set_end()
}
fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> {
self.inner.read_map_begin()
}
fn read_map_end(&mut self) -> ::Result<()> {
self.inner.read_map_end()
}
// utility
//
fn read_byte(&mut self) -> ::Result<u8> {
self.inner.read_byte()
}
}

View file

@ -0,0 +1,124 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Types used to implement a Thrift server.
use {ApplicationError, ApplicationErrorKind};
use protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol};
mod multiplexed;
mod threaded;
pub use self::multiplexed::TMultiplexedProcessor;
pub use self::threaded::TServer;
/// Handles incoming Thrift messages and dispatches them to the user-defined
/// handler functions.
///
/// An implementation is auto-generated for each Thrift service. When used by a
/// server (for example, a `TSimpleServer`), it will demux incoming service
/// calls and invoke the corresponding user-defined handler function.
///
/// # Examples
///
/// Create and start a server using the auto-generated `TProcessor` for
/// a Thrift service `SimpleService`.
///
/// ```no_run
/// use thrift;
/// use thrift::protocol::{TInputProtocol, TOutputProtocol};
/// use thrift::server::TProcessor;
///
/// //
/// // auto-generated
/// //
///
/// // processor for `SimpleService`
/// struct SimpleServiceSyncProcessor;
/// impl SimpleServiceSyncProcessor {
/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor {
/// unimplemented!();
/// }
/// }
///
/// // `TProcessor` implementation for `SimpleService`
/// impl TProcessor for SimpleServiceSyncProcessor {
/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> {
/// unimplemented!();
/// }
/// }
///
/// // service functions for SimpleService
/// trait SimpleServiceSyncHandler {
/// fn service_call(&self) -> thrift::Result<()>;
/// }
///
/// //
/// // user-code follows
/// //
///
/// // define a handler that will be invoked when `service_call` is received
/// struct SimpleServiceHandlerImpl;
/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl {
/// fn service_call(&self) -> thrift::Result<()> {
/// unimplemented!();
/// }
/// }
///
/// // instantiate the processor
/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {});
///
/// // at this point you can pass the processor to the server
/// // let server = TServer::new(..., processor);
/// ```
pub trait TProcessor {
/// Process a Thrift service call.
///
/// Reads arguments from `i`, executes the user's handler code, and writes
/// the response to `o`.
///
/// Returns `()` if the handler was executed; `Err` otherwise.
fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> ::Result<()>;
}
/// Convenience function used in generated `TProcessor` implementations to
/// return an `ApplicationError` if thrift message processing failed.
pub fn handle_process_result(
msg_ident: &TMessageIdentifier,
res: ::Result<()>,
o_prot: &mut TOutputProtocol,
) -> ::Result<()> {
if let Err(e) = res {
let e = match e {
::Error::Application(a) => a,
_ => ApplicationError::new(ApplicationErrorKind::Unknown, format!("{:?}", e)),
};
let ident = TMessageIdentifier::new(
msg_ident.name.clone(),
TMessageType::Exception,
msg_ident.sequence_number,
);
o_prot.write_message_begin(&ident)?;
::Error::write_application_error_to_out_protocol(&e, o_prot)?;
o_prot.write_message_end()?;
o_prot.flush()
} else {
Ok(())
}
}

View file

@ -0,0 +1,344 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::collections::HashMap;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::convert::Into;
use std::sync::{Arc, Mutex};
use protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol};
use super::{TProcessor, handle_process_result};
const MISSING_SEPARATOR_AND_NO_DEFAULT: &'static str = "missing service separator and no default processor set";
type ThreadSafeProcessor = Box<TProcessor + Send + Sync>;
/// A `TProcessor` that can demux service calls to multiple underlying
/// Thrift services.
///
/// Users register service-specific `TProcessor` instances with a
/// `TMultiplexedProcessor`, and then register that processor with a server
/// implementation. Following that, all incoming service calls are automatically
/// routed to the service-specific `TProcessor`.
///
/// A `TMultiplexedProcessor` can only handle messages sent by a
/// `TMultiplexedOutputProtocol`.
#[derive(Default)]
pub struct TMultiplexedProcessor {
stored: Mutex<StoredProcessors>,
}
#[derive(Default)]
struct StoredProcessors {
processors: HashMap<String, Arc<ThreadSafeProcessor>>,
default_processor: Option<Arc<ThreadSafeProcessor>>,
}
impl TMultiplexedProcessor {
/// Create a new `TMultiplexedProcessor` with no registered service-specific
/// processors.
pub fn new() -> TMultiplexedProcessor {
TMultiplexedProcessor {
stored: Mutex::new(
StoredProcessors {
processors: HashMap::new(),
default_processor: None,
},
),
}
}
/// Register a service-specific `processor` for the service named
/// `service_name`. This implementation is also backwards-compatible with
/// non-multiplexed clients. Set `as_default` to `true` to allow
/// non-namespaced requests to be dispatched to a default processor.
///
/// Returns success if a new entry was inserted. Returns an error if:
/// * A processor exists for `service_name`
/// * You attempt to register a processor as default, and an existing default exists
#[cfg_attr(feature = "cargo-clippy", allow(map_entry))]
pub fn register<S: Into<String>>(
&mut self,
service_name: S,
processor: Box<TProcessor + Send + Sync>,
as_default: bool,
) -> ::Result<()> {
let mut stored = self.stored.lock().unwrap();
let name = service_name.into();
if !stored.processors.contains_key(&name) {
let processor = Arc::new(processor);
if as_default {
if stored.default_processor.is_none() {
stored.processors.insert(name, processor.clone());
stored.default_processor = Some(processor.clone());
Ok(())
} else {
Err("cannot reset default processor".into())
}
} else {
stored.processors.insert(name, processor);
Ok(())
}
} else {
Err(format!("cannot overwrite existing processor for service {}", name).into(),)
}
}
fn process_message(
&self,
msg_ident: &TMessageIdentifier,
i_prot: &mut TInputProtocol,
o_prot: &mut TOutputProtocol,
) -> ::Result<()> {
let (svc_name, svc_call) = split_ident_name(&msg_ident.name);
debug!("routing svc_name {:?} svc_call {}", &svc_name, &svc_call);
let processor: Option<Arc<ThreadSafeProcessor>> = {
let stored = self.stored.lock().unwrap();
if let Some(name) = svc_name {
stored.processors.get(name).cloned()
} else {
stored.default_processor.clone()
}
};
match processor {
Some(arc) => {
let new_msg_ident = TMessageIdentifier::new(
svc_call,
msg_ident.message_type,
msg_ident.sequence_number,
);
let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident);
(*arc).process(&mut proxy_i_prot, o_prot)
}
None => Err(missing_processor_message(svc_name).into()),
}
}
}
impl TProcessor for TMultiplexedProcessor {
fn process(&self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> ::Result<()> {
let msg_ident = i_prot.read_message_begin()?;
debug!("process incoming msg id:{:?}", &msg_ident);
let res = self.process_message(&msg_ident, i_prot, o_prot);
handle_process_result(&msg_ident, res, o_prot)
}
}
impl Debug for TMultiplexedProcessor {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
let stored = self.stored.lock().unwrap();
write!(
f,
"TMultiplexedProcess {{ registered_count: {:?} default: {:?} }}",
stored.processors.keys().len(),
stored.default_processor.is_some()
)
}
}
fn split_ident_name(ident_name: &str) -> (Option<&str>, &str) {
ident_name
.find(':')
.map(
|pos| {
let (svc_name, svc_call) = ident_name.split_at(pos);
let (_, svc_call) = svc_call.split_at(1); // remove colon from service call name
(Some(svc_name), svc_call)
},
)
.or_else(|| Some((None, ident_name)))
.unwrap()
}
fn missing_processor_message(svc_name: Option<&str>) -> String {
match svc_name {
Some(name) => format!("no processor found for service {}", name),
None => MISSING_SEPARATOR_AND_NO_DEFAULT.to_owned(),
}
}
#[cfg(test)]
mod tests {
use std::convert::Into;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use {ApplicationError, ApplicationErrorKind};
use protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TMessageIdentifier, TMessageType};
use transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
use super::*;
#[test]
fn should_split_name_into_proper_separator_and_service_call() {
let ident_name = "foo:bar_call";
let (serv, call) = split_ident_name(&ident_name);
assert_eq!(serv, Some("foo"));
assert_eq!(call, "bar_call");
}
#[test]
fn should_return_full_ident_if_no_separator_exists() {
let ident_name = "bar_call";
let (serv, call) = split_ident_name(&ident_name);
assert_eq!(serv, None);
assert_eq!(call, "bar_call");
}
#[test]
fn should_write_error_if_no_separator_found_and_no_default_processor_exists() {
let (mut i, mut o) = build_objects();
let sent_ident = TMessageIdentifier::new("foo", TMessageType::Call, 10);
o.write_message_begin(&sent_ident).unwrap();
o.flush().unwrap();
o.transport.copy_write_buffer_to_read_buffer();
o.transport.empty_write_buffer();
let p = TMultiplexedProcessor::new();
p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out
i.transport
.set_readable_bytes(&o.transport.write_bytes());
let rcvd_ident = i.read_message_begin().unwrap();
let expected_ident = TMessageIdentifier::new("foo", TMessageType::Exception, 10);
assert_eq!(rcvd_ident, expected_ident);
let rcvd_err = ::Error::read_application_error_from_in_protocol(&mut i).unwrap();
let expected_err = ApplicationError::new(
ApplicationErrorKind::Unknown,
MISSING_SEPARATOR_AND_NO_DEFAULT,
);
assert_eq!(rcvd_err, expected_err);
}
#[test]
fn should_write_error_if_separator_exists_and_no_processor_found() {
let (mut i, mut o) = build_objects();
let sent_ident = TMessageIdentifier::new("missing:call", TMessageType::Call, 10);
o.write_message_begin(&sent_ident).unwrap();
o.flush().unwrap();
o.transport.copy_write_buffer_to_read_buffer();
o.transport.empty_write_buffer();
let p = TMultiplexedProcessor::new();
p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out
i.transport
.set_readable_bytes(&o.transport.write_bytes());
let rcvd_ident = i.read_message_begin().unwrap();
let expected_ident = TMessageIdentifier::new("missing:call", TMessageType::Exception, 10);
assert_eq!(rcvd_ident, expected_ident);
let rcvd_err = ::Error::read_application_error_from_in_protocol(&mut i).unwrap();
let expected_err = ApplicationError::new(
ApplicationErrorKind::Unknown,
missing_processor_message(Some("missing")),
);
assert_eq!(rcvd_err, expected_err);
}
#[derive(Default)]
struct Service {
pub invoked: Arc<AtomicBool>,
}
impl TProcessor for Service {
fn process(&self, _: &mut TInputProtocol, _: &mut TOutputProtocol) -> ::Result<()> {
let res = self.invoked
.compare_and_swap(false, true, Ordering::Relaxed);
if res {
Ok(())
} else {
Err("failed swap".into())
}
}
}
#[test]
fn should_route_call_to_correct_processor() {
let (mut i, mut o) = build_objects();
// build the services
let svc_1 = Service { invoked: Arc::new(AtomicBool::new(false)) };
let atm_1 = svc_1.invoked.clone();
let svc_2 = Service { invoked: Arc::new(AtomicBool::new(false)) };
let atm_2 = svc_2.invoked.clone();
// register them
let mut p = TMultiplexedProcessor::new();
p.register("service_1", Box::new(svc_1), false).unwrap();
p.register("service_2", Box::new(svc_2), false).unwrap();
// make the service call
let sent_ident = TMessageIdentifier::new("service_1:call", TMessageType::Call, 10);
o.write_message_begin(&sent_ident).unwrap();
o.flush().unwrap();
o.transport.copy_write_buffer_to_read_buffer();
o.transport.empty_write_buffer();
p.process(&mut i, &mut o).unwrap();
// service 1 should have been invoked, not service 2
assert_eq!(atm_1.load(Ordering::Relaxed), true);
assert_eq!(atm_2.load(Ordering::Relaxed), false);
}
#[test]
fn should_route_call_to_correct_processor_if_no_separator_exists_and_default_processor_set() {
let (mut i, mut o) = build_objects();
// build the services
let svc_1 = Service { invoked: Arc::new(AtomicBool::new(false)) };
let atm_1 = svc_1.invoked.clone();
let svc_2 = Service { invoked: Arc::new(AtomicBool::new(false)) };
let atm_2 = svc_2.invoked.clone();
// register them
let mut p = TMultiplexedProcessor::new();
p.register("service_1", Box::new(svc_1), false).unwrap();
p.register("service_2", Box::new(svc_2), true).unwrap(); // second processor is default
// make the service call (it's an old client, so we have to be backwards compatible)
let sent_ident = TMessageIdentifier::new("old_call", TMessageType::Call, 10);
o.write_message_begin(&sent_ident).unwrap();
o.flush().unwrap();
o.transport.copy_write_buffer_to_read_buffer();
o.transport.empty_write_buffer();
p.process(&mut i, &mut o).unwrap();
// service 2 should have been invoked, not service 1
assert_eq!(atm_1.load(Ordering::Relaxed), false);
assert_eq!(atm_2.load(Ordering::Relaxed), true);
}
fn build_objects()
-> (TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
TBinaryOutputProtocol<WriteHalf<TBufferChannel>>)
{
let c = TBufferChannel::with_capacity(128, 128);
let (r_c, w_c) = c.split().unwrap();
(TBinaryInputProtocol::new(r_c, true), TBinaryOutputProtocol::new(w_c, true))
}
}

View file

@ -0,0 +1,240 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use threadpool::ThreadPool;
use {ApplicationError, ApplicationErrorKind};
use protocol::{TInputProtocol, TInputProtocolFactory, TOutputProtocol, TOutputProtocolFactory};
use transport::{TIoChannel, TReadTransportFactory, TTcpChannel, TWriteTransportFactory};
use super::TProcessor;
/// Fixed-size thread-pool blocking Thrift server.
///
/// A `TServer` listens on a given address and submits accepted connections
/// to an **unbounded** queue. Connections from this queue are serviced by
/// the first available worker thread from a **fixed-size** thread pool. Each
/// accepted connection is handled by that worker thread, and communication
/// over this thread occurs sequentially and synchronously (i.e. calls block).
/// Accepted connections have an input half and an output half, each of which
/// uses a `TTransport` and `TInputProtocol`/`TOutputProtocol` to translate
/// messages to and from byes. Any combination of `TInputProtocol`, `TOutputProtocol`
/// and `TTransport` may be used.
///
/// # Examples
///
/// Creating and running a `TServer` using Thrift-compiler-generated
/// service code.
///
/// ```no_run
/// use thrift;
/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory};
/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory};
/// use thrift::protocol::{TInputProtocol, TOutputProtocol};
/// use thrift::transport::{TBufferedReadTransportFactory, TBufferedWriteTransportFactory,
/// TReadTransportFactory, TWriteTransportFactory};
/// use thrift::server::{TProcessor, TServer};
///
/// //
/// // auto-generated
/// //
///
/// // processor for `SimpleService`
/// struct SimpleServiceSyncProcessor;
/// impl SimpleServiceSyncProcessor {
/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor {
/// unimplemented!();
/// }
/// }
///
/// // `TProcessor` implementation for `SimpleService`
/// impl TProcessor for SimpleServiceSyncProcessor {
/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> {
/// unimplemented!();
/// }
/// }
///
/// // service functions for SimpleService
/// trait SimpleServiceSyncHandler {
/// fn service_call(&self) -> thrift::Result<()>;
/// }
///
/// //
/// // user-code follows
/// //
///
/// // define a handler that will be invoked when `service_call` is received
/// struct SimpleServiceHandlerImpl;
/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl {
/// fn service_call(&self) -> thrift::Result<()> {
/// unimplemented!();
/// }
/// }
///
/// // instantiate the processor
/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {});
///
/// // instantiate the server
/// let i_tr_fact: Box<TReadTransportFactory> = Box::new(TBufferedReadTransportFactory::new());
/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new());
/// let o_tr_fact: Box<TWriteTransportFactory> = Box::new(TBufferedWriteTransportFactory::new());
/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new());
///
/// let mut server = TServer::new(
/// i_tr_fact,
/// i_pr_fact,
/// o_tr_fact,
/// o_pr_fact,
/// processor,
/// 10
/// );
///
/// // start listening for incoming connections
/// match server.listen("127.0.0.1:8080") {
/// Ok(_) => println!("listen completed"),
/// Err(e) => println!("listen failed with error {:?}", e),
/// }
/// ```
#[derive(Debug)]
pub struct TServer<PRC, RTF, IPF, WTF, OPF>
where
PRC: TProcessor + Send + Sync + 'static,
RTF: TReadTransportFactory + 'static,
IPF: TInputProtocolFactory + 'static,
WTF: TWriteTransportFactory + 'static,
OPF: TOutputProtocolFactory + 'static,
{
r_trans_factory: RTF,
i_proto_factory: IPF,
w_trans_factory: WTF,
o_proto_factory: OPF,
processor: Arc<PRC>,
worker_pool: ThreadPool,
}
impl<PRC, RTF, IPF, WTF, OPF> TServer<PRC, RTF, IPF, WTF, OPF>
where PRC: TProcessor + Send + Sync + 'static,
RTF: TReadTransportFactory + 'static,
IPF: TInputProtocolFactory + 'static,
WTF: TWriteTransportFactory + 'static,
OPF: TOutputProtocolFactory + 'static {
/// Create a `TServer`.
///
/// Each accepted connection has an input and output half, each of which
/// requires a `TTransport` and `TProtocol`. `TServer` uses
/// `read_transport_factory` and `input_protocol_factory` to create
/// implementations for the input, and `write_transport_factory` and
/// `output_protocol_factory` to create implementations for the output.
pub fn new(
read_transport_factory: RTF,
input_protocol_factory: IPF,
write_transport_factory: WTF,
output_protocol_factory: OPF,
processor: PRC,
num_workers: usize,
) -> TServer<PRC, RTF, IPF, WTF, OPF> {
TServer {
r_trans_factory: read_transport_factory,
i_proto_factory: input_protocol_factory,
w_trans_factory: write_transport_factory,
o_proto_factory: output_protocol_factory,
processor: Arc::new(processor),
worker_pool: ThreadPool::with_name(
"Thrift service processor".to_owned(),
num_workers,
),
}
}
/// Listen for incoming connections on `listen_address`.
///
/// `listen_address` should be in the form `host:port`,
/// for example: `127.0.0.1:8080`.
///
/// Return `()` if successful.
///
/// Return `Err` when the server cannot bind to `listen_address` or there
/// is an unrecoverable error.
pub fn listen(&mut self, listen_address: &str) -> ::Result<()> {
let listener = TcpListener::bind(listen_address)?;
for stream in listener.incoming() {
match stream {
Ok(s) => {
let (i_prot, o_prot) = self.new_protocols_for_connection(s)?;
let processor = self.processor.clone();
self.worker_pool
.execute(move || handle_incoming_connection(processor, i_prot, o_prot),);
}
Err(e) => {
warn!("failed to accept remote connection with error {:?}", e);
}
}
}
Err(
::Error::Application(
ApplicationError {
kind: ApplicationErrorKind::Unknown,
message: "aborted listen loop".into(),
},
),
)
}
fn new_protocols_for_connection(
&mut self,
stream: TcpStream,
) -> ::Result<(Box<TInputProtocol + Send>, Box<TOutputProtocol + Send>)> {
// create the shared tcp stream
let channel = TTcpChannel::with_stream(stream);
// split it into two - one to be owned by the
// input tran/proto and the other by the output
let (r_chan, w_chan) = channel.split()?;
// input protocol and transport
let r_tran = self.r_trans_factory.create(Box::new(r_chan));
let i_prot = self.i_proto_factory.create(r_tran);
// output protocol and transport
let w_tran = self.w_trans_factory.create(Box::new(w_chan));
let o_prot = self.o_proto_factory.create(w_tran);
Ok((i_prot, o_prot))
}
}
fn handle_incoming_connection<PRC>(
processor: Arc<PRC>,
i_prot: Box<TInputProtocol>,
o_prot: Box<TOutputProtocol>,
) where
PRC: TProcessor,
{
let mut i_prot = i_prot;
let mut o_prot = o_prot;
loop {
let r = processor.process(&mut *i_prot, &mut *o_prot);
if let Err(e) = r {
warn!("processor completed with error: {:?}", e);
break;
}
}
}

View file

@ -0,0 +1,480 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::cmp;
use std::io;
use std::io::{Read, Write};
use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory};
/// Default capacity of the read buffer in bytes.
const READ_CAPACITY: usize = 4096;
/// Default capacity of the write buffer in bytes..
const WRITE_CAPACITY: usize = 4096;
/// Transport that reads messages via an internal buffer.
///
/// A `TBufferedReadTransport` maintains a fixed-size internal read buffer.
/// On a call to `TBufferedReadTransport::read(...)` one full message - both
/// fixed-length header and bytes - is read from the wrapped channel and buffered.
/// Subsequent read calls are serviced from the internal buffer until it is
/// exhausted, at which point the next full message is read from the wrapped
/// channel.
///
/// # Examples
///
/// Create and use a `TBufferedReadTransport`.
///
/// ```no_run
/// use std::io::Read;
/// use thrift::transport::{TBufferedReadTransport, TTcpChannel};
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut t = TBufferedReadTransport::new(c);
///
/// t.read(&mut vec![0u8; 1]).unwrap();
/// ```
#[derive(Debug)]
pub struct TBufferedReadTransport<C>
where
C: Read,
{
buf: Box<[u8]>,
pos: usize,
cap: usize,
chan: C,
}
impl<C> TBufferedReadTransport<C>
where
C: Read,
{
/// Create a `TBufferedTransport` with default-sized internal read and
/// write buffers that wraps the given `TIoChannel`.
pub fn new(channel: C) -> TBufferedReadTransport<C> {
TBufferedReadTransport::with_capacity(READ_CAPACITY, channel)
}
/// Create a `TBufferedTransport` with an internal read buffer of size
/// `read_capacity` and an internal write buffer of size
/// `write_capacity` that wraps the given `TIoChannel`.
pub fn with_capacity(read_capacity: usize, channel: C) -> TBufferedReadTransport<C> {
TBufferedReadTransport {
buf: vec![0; read_capacity].into_boxed_slice(),
pos: 0,
cap: 0,
chan: channel,
}
}
fn get_bytes(&mut self) -> io::Result<&[u8]> {
if self.cap - self.pos == 0 {
self.pos = 0;
self.cap = self.chan.read(&mut self.buf)?;
}
Ok(&self.buf[self.pos..self.cap])
}
fn consume(&mut self, consumed: usize) {
// TODO: was a bug here += <-- test somehow
self.pos = cmp::min(self.cap, self.pos + consumed);
}
}
impl<C> Read for TBufferedReadTransport<C>
where
C: Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut bytes_read = 0;
loop {
let nread = {
let avail_bytes = self.get_bytes()?;
let avail_space = buf.len() - bytes_read;
let nread = cmp::min(avail_space, avail_bytes.len());
buf[bytes_read..(bytes_read + nread)].copy_from_slice(&avail_bytes[..nread]);
nread
};
self.consume(nread);
bytes_read += nread;
if bytes_read == buf.len() || nread == 0 {
break;
}
}
Ok(bytes_read)
}
}
/// Factory for creating instances of `TBufferedReadTransport`.
#[derive(Default)]
pub struct TBufferedReadTransportFactory;
impl TBufferedReadTransportFactory {
pub fn new() -> TBufferedReadTransportFactory {
TBufferedReadTransportFactory {}
}
}
impl TReadTransportFactory for TBufferedReadTransportFactory {
/// Create a `TBufferedReadTransport`.
fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> {
Box::new(TBufferedReadTransport::new(channel))
}
}
/// Transport that writes messages via an internal buffer.
///
/// A `TBufferedWriteTransport` maintains a fixed-size internal write buffer.
/// All writes are made to this buffer and are sent to the wrapped channel only
/// when `TBufferedWriteTransport::flush()` is called. On a flush a fixed-length
/// header with a count of the buffered bytes is written, followed by the bytes
/// themselves.
///
/// # Examples
///
/// Create and use a `TBufferedWriteTransport`.
///
/// ```no_run
/// use std::io::Write;
/// use thrift::transport::{TBufferedWriteTransport, TTcpChannel};
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut t = TBufferedWriteTransport::new(c);
///
/// t.write(&[0x00]).unwrap();
/// t.flush().unwrap();
/// ```
#[derive(Debug)]
pub struct TBufferedWriteTransport<C>
where
C: Write,
{
buf: Vec<u8>,
cap: usize,
channel: C,
}
impl<C> TBufferedWriteTransport<C>
where
C: Write,
{
/// Create a `TBufferedTransport` with default-sized internal read and
/// write buffers that wraps the given `TIoChannel`.
pub fn new(channel: C) -> TBufferedWriteTransport<C> {
TBufferedWriteTransport::with_capacity(WRITE_CAPACITY, channel)
}
/// Create a `TBufferedTransport` with an internal read buffer of size
/// `read_capacity` and an internal write buffer of size
/// `write_capacity` that wraps the given `TIoChannel`.
pub fn with_capacity(write_capacity: usize, channel: C) -> TBufferedWriteTransport<C> {
assert!(write_capacity > 0, "write buffer size must be a positive integer");
TBufferedWriteTransport {
buf: Vec::with_capacity(write_capacity),
cap: write_capacity,
channel: channel,
}
}
}
impl<C> Write for TBufferedWriteTransport<C>
where
C: Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if !buf.is_empty() {
let mut avail_bytes;
loop {
avail_bytes = cmp::min(buf.len(), self.cap - self.buf.len());
if avail_bytes == 0 {
self.flush()?;
} else {
break;
}
}
let avail_bytes = avail_bytes;
self.buf.extend_from_slice(&buf[..avail_bytes]);
assert!(self.buf.len() <= self.cap, "copy overflowed buffer");
Ok(avail_bytes)
} else {
Ok(0)
}
}
fn flush(&mut self) -> io::Result<()> {
self.channel.write_all(&self.buf)?;
self.channel.flush()?;
self.buf.clear();
Ok(())
}
}
/// Factory for creating instances of `TBufferedWriteTransport`.
#[derive(Default)]
pub struct TBufferedWriteTransportFactory;
impl TBufferedWriteTransportFactory {
pub fn new() -> TBufferedWriteTransportFactory {
TBufferedWriteTransportFactory {}
}
}
impl TWriteTransportFactory for TBufferedWriteTransportFactory {
/// Create a `TBufferedWriteTransport`.
fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> {
Box::new(TBufferedWriteTransport::new(channel))
}
}
#[cfg(test)]
mod tests {
use std::io::{Read, Write};
use super::*;
use transport::TBufferChannel;
#[test]
fn must_return_zero_if_read_buffer_is_empty() {
let mem = TBufferChannel::with_capacity(10, 0);
let mut t = TBufferedReadTransport::with_capacity(10, mem);
let mut b = vec![0; 10];
let read_result = t.read(&mut b);
assert_eq!(read_result.unwrap(), 0);
}
#[test]
fn must_return_zero_if_caller_reads_into_zero_capacity_buffer() {
let mem = TBufferChannel::with_capacity(10, 0);
let mut t = TBufferedReadTransport::with_capacity(10, mem);
let read_result = t.read(&mut []);
assert_eq!(read_result.unwrap(), 0);
}
#[test]
fn must_return_zero_if_nothing_more_can_be_read() {
let mem = TBufferChannel::with_capacity(4, 0);
let mut t = TBufferedReadTransport::with_capacity(4, mem);
t.chan.set_readable_bytes(&[0, 1, 2, 3]);
// read buffer is exactly the same size as bytes available
let mut buf = vec![0u8; 4];
let read_result = t.read(&mut buf);
// we've read exactly 4 bytes
assert_eq!(read_result.unwrap(), 4);
assert_eq!(&buf, &[0, 1, 2, 3]);
// try read again
let buf_again = vec![0u8; 4];
let read_result = t.read(&mut buf);
// this time, 0 bytes and we haven't changed the buffer
assert_eq!(read_result.unwrap(), 0);
assert_eq!(&buf_again, &[0, 0, 0, 0])
}
#[test]
fn must_fill_user_buffer_with_only_as_many_bytes_as_available() {
let mem = TBufferChannel::with_capacity(4, 0);
let mut t = TBufferedReadTransport::with_capacity(4, mem);
t.chan.set_readable_bytes(&[0, 1, 2, 3]);
// read buffer is much larger than the bytes available
let mut buf = vec![0u8; 8];
let read_result = t.read(&mut buf);
// we've read exactly 4 bytes
assert_eq!(read_result.unwrap(), 4);
assert_eq!(&buf[..4], &[0, 1, 2, 3]);
// try read again
let read_result = t.read(&mut buf[4..]);
// this time, 0 bytes and we haven't changed the buffer
assert_eq!(read_result.unwrap(), 0);
assert_eq!(&buf, &[0, 1, 2, 3, 0, 0, 0, 0])
}
#[test]
fn must_read_successfully() {
// this test involves a few loops within the buffered transport
// itself where it has to drain the underlying transport in order
// to service a read
// we have a much smaller buffer than the
// underlying transport has bytes available
let mem = TBufferChannel::with_capacity(10, 0);
let mut t = TBufferedReadTransport::with_capacity(2, mem);
// fill the underlying transport's byte buffer
let mut readable_bytes = [0u8; 10];
for i in 0..10 {
readable_bytes[i] = i as u8;
}
t.chan.set_readable_bytes(&readable_bytes);
// we ask to read into a buffer that's much larger
// than the one the buffered transport has; as a result
// it's going to have to keep asking the underlying
// transport for more bytes
let mut buf = [0u8; 8];
let read_result = t.read(&mut buf);
// we should have read 8 bytes
assert_eq!(read_result.unwrap(), 8);
assert_eq!(&buf, &[0, 1, 2, 3, 4, 5, 6, 7]);
// let's clear out the buffer and try read again
for i in 0..8 {
buf[i] = 0;
}
let read_result = t.read(&mut buf);
// this time we were only able to read 2 bytes
// (all that's remaining from the underlying transport)
// let's also check that the remaining bytes are untouched
assert_eq!(read_result.unwrap(), 2);
assert_eq!(&buf[0..2], &[8, 9]);
assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]);
// try read again (we should get 0)
// and all the existing bytes were untouched
let read_result = t.read(&mut buf);
assert_eq!(read_result.unwrap(), 0);
assert_eq!(&buf[0..2], &[8, 9]);
assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]);
}
#[test]
fn must_return_error_when_nothing_can_be_written_to_underlying_channel() {
let mem = TBufferChannel::with_capacity(0, 0);
let mut t = TBufferedWriteTransport::with_capacity(1, mem);
let b = vec![0; 10];
let r = t.write(&b);
// should have written 1 byte
assert_eq!(r.unwrap(), 1);
// let's try again...
let r = t.write(&b[1..]);
// this time we'll error out because the auto-flush failed
assert!(r.is_err());
}
#[test]
fn must_return_zero_if_caller_calls_write_with_empty_buffer() {
let mem = TBufferChannel::with_capacity(0, 10);
let mut t = TBufferedWriteTransport::with_capacity(10, mem);
let r = t.write(&[]);
let expected: [u8; 0] = [];
assert_eq!(r.unwrap(), 0);
assert_eq_transport_written_bytes!(t, expected);
}
#[test]
fn must_auto_flush_if_write_buffer_full() {
let mem = TBufferChannel::with_capacity(0, 8);
let mut t = TBufferedWriteTransport::with_capacity(4, mem);
let b0 = [0x00, 0x01, 0x02, 0x03];
let b1 = [0x04, 0x05, 0x06, 0x07];
// write the first 4 bytes; we've now filled the transport's write buffer
let r = t.write(&b0);
assert_eq!(r.unwrap(), 4);
// try write the next 4 bytes; this causes the transport to auto-flush the first 4 bytes
let r = t.write(&b1);
assert_eq!(r.unwrap(), 4);
// check that in writing the second 4 bytes we auto-flushed the first 4 bytes
assert_eq_transport_num_written_bytes!(t, 4);
assert_eq_transport_written_bytes!(t, b0);
t.channel.empty_write_buffer();
// now flush the transport to push the second 4 bytes to the underlying channel
assert!(t.flush().is_ok());
// check that we wrote out the second 4 bytes
assert_eq_transport_written_bytes!(t, b1);
}
#[test]
fn must_write_to_inner_transport_on_flush() {
let mem = TBufferChannel::with_capacity(10, 10);
let mut t = TBufferedWriteTransport::new(mem);
let b: [u8; 5] = [0, 1, 2, 3, 4];
assert_eq!(t.write(&b).unwrap(), 5);
assert_eq_transport_num_written_bytes!(t, 0);
assert!(t.flush().is_ok());
assert_eq_transport_written_bytes!(t, b);
}
#[test]
fn must_write_successfully_after_flush() {
let mem = TBufferChannel::with_capacity(0, 5);
let mut t = TBufferedWriteTransport::with_capacity(5, mem);
// write and flush
let b: [u8; 5] = [0, 1, 2, 3, 4];
assert_eq!(t.write(&b).unwrap(), 5);
assert!(t.flush().is_ok());
// check the flushed bytes
assert_eq_transport_written_bytes!(t, b);
// reset our underlying transport
t.channel.empty_write_buffer();
// write and flush again
assert_eq!(t.write(&b).unwrap(), 5);
assert!(t.flush().is_ok());
// check the flushed bytes
assert_eq_transport_written_bytes!(t, b);
}
}

View file

@ -0,0 +1,468 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use std::cmp;
use std::io;
use std::io::{Read, Write};
use super::{TReadTransport, TReadTransportFactory, TWriteTransport, TWriteTransportFactory};
/// Default capacity of the read buffer in bytes.
const READ_CAPACITY: usize = 4096;
/// Default capacity of the write buffer in bytes.
const WRITE_CAPACITY: usize = 4096;
/// Transport that reads framed messages.
///
/// A `TFramedReadTransport` maintains a fixed-size internal read buffer.
/// On a call to `TFramedReadTransport::read(...)` one full message - both
/// fixed-length header and bytes - is read from the wrapped channel and
/// buffered. Subsequent read calls are serviced from the internal buffer
/// until it is exhausted, at which point the next full message is read
/// from the wrapped channel.
///
/// # Examples
///
/// Create and use a `TFramedReadTransport`.
///
/// ```no_run
/// use std::io::Read;
/// use thrift::transport::{TFramedReadTransport, TTcpChannel};
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut t = TFramedReadTransport::new(c);
///
/// t.read(&mut vec![0u8; 1]).unwrap();
/// ```
#[derive(Debug)]
pub struct TFramedReadTransport<C>
where
C: Read,
{
buf: Vec<u8>,
pos: usize,
cap: usize,
chan: C,
}
impl<C> TFramedReadTransport<C>
where
C: Read,
{
/// Create a `TFramedReadTransport` with a default-sized
/// internal read buffer that wraps the given `TIoChannel`.
pub fn new(channel: C) -> TFramedReadTransport<C> {
TFramedReadTransport::with_capacity(READ_CAPACITY, channel)
}
/// Create a `TFramedTransport` with an internal read buffer
/// of size `read_capacity` that wraps the given `TIoChannel`.
pub fn with_capacity(read_capacity: usize, channel: C) -> TFramedReadTransport<C> {
TFramedReadTransport {
buf: vec![0; read_capacity], // FIXME: do I actually have to do this?
pos: 0,
cap: 0,
chan: channel,
}
}
}
impl<C> Read for TFramedReadTransport<C>
where
C: Read,
{
fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
if self.cap - self.pos == 0 {
let message_size = self.chan.read_i32::<BigEndian>()? as usize;
let buf_capacity = cmp::max(message_size, READ_CAPACITY);
self.buf.resize(buf_capacity, 0);
self.chan.read_exact(&mut self.buf[..message_size])?;
self.cap = message_size as usize;
self.pos = 0;
}
let nread = cmp::min(b.len(), self.cap - self.pos);
b[..nread].clone_from_slice(&self.buf[self.pos..self.pos + nread]);
self.pos += nread;
Ok(nread)
}
}
/// Factory for creating instances of `TFramedReadTransport`.
#[derive(Default)]
pub struct TFramedReadTransportFactory;
impl TFramedReadTransportFactory {
pub fn new() -> TFramedReadTransportFactory {
TFramedReadTransportFactory {}
}
}
impl TReadTransportFactory for TFramedReadTransportFactory {
/// Create a `TFramedReadTransport`.
fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> {
Box::new(TFramedReadTransport::new(channel))
}
}
/// Transport that writes framed messages.
///
/// A `TFramedWriteTransport` maintains a fixed-size internal write buffer. All
/// writes are made to this buffer and are sent to the wrapped channel only
/// when `TFramedWriteTransport::flush()` is called. On a flush a fixed-length
/// header with a count of the buffered bytes is written, followed by the bytes
/// themselves.
///
/// # Examples
///
/// Create and use a `TFramedWriteTransport`.
///
/// ```no_run
/// use std::io::Write;
/// use thrift::transport::{TFramedWriteTransport, TTcpChannel};
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut t = TFramedWriteTransport::new(c);
///
/// t.write(&[0x00]).unwrap();
/// t.flush().unwrap();
/// ```
#[derive(Debug)]
pub struct TFramedWriteTransport<C>
where
C: Write,
{
buf: Vec<u8>,
channel: C,
}
impl<C> TFramedWriteTransport<C>
where
C: Write,
{
/// Create a `TFramedWriteTransport` with default-sized internal
/// write buffer that wraps the given `TIoChannel`.
pub fn new(channel: C) -> TFramedWriteTransport<C> {
TFramedWriteTransport::with_capacity(WRITE_CAPACITY, channel)
}
/// Create a `TFramedWriteTransport` with an internal write buffer
/// of size `write_capacity` that wraps the given `TIoChannel`.
pub fn with_capacity(write_capacity: usize, channel: C) -> TFramedWriteTransport<C> {
TFramedWriteTransport {
buf: Vec::with_capacity(write_capacity),
channel,
}
}
}
impl<C> Write for TFramedWriteTransport<C>
where
C: Write,
{
fn write(&mut self, b: &[u8]) -> io::Result<usize> {
let current_capacity = self.buf.capacity();
let available_space = current_capacity - self.buf.len();
if b.len() > available_space {
let additional_space = cmp::max(b.len() - available_space, current_capacity);
self.buf.reserve(additional_space);
}
self.buf.extend_from_slice(b);
Ok(b.len())
}
fn flush(&mut self) -> io::Result<()> {
let message_size = self.buf.len();
if let 0 = message_size {
return Ok(());
} else {
self.channel
.write_i32::<BigEndian>(message_size as i32)?;
}
// will spin if the underlying channel can't be written to
let mut byte_index = 0;
while byte_index < message_size {
let nwrite = self.channel.write(&self.buf[byte_index..message_size])?;
byte_index = cmp::min(byte_index + nwrite, message_size);
}
let buf_capacity = cmp::min(self.buf.capacity(), WRITE_CAPACITY);
self.buf.resize(buf_capacity, 0);
self.buf.clear();
self.channel.flush()
}
}
/// Factory for creating instances of `TFramedWriteTransport`.
#[derive(Default)]
pub struct TFramedWriteTransportFactory;
impl TFramedWriteTransportFactory {
pub fn new() -> TFramedWriteTransportFactory {
TFramedWriteTransportFactory {}
}
}
impl TWriteTransportFactory for TFramedWriteTransportFactory {
/// Create a `TFramedWriteTransport`.
fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> {
Box::new(TFramedWriteTransport::new(channel))
}
}
#[cfg(test)]
mod tests {
use super::*;
use ::transport::mem::TBufferChannel;
// FIXME: test a forced reserve
#[test]
fn must_read_message_smaller_than_initial_buffer_size() {
let c = TBufferChannel::with_capacity(10, 10);
let mut t = TFramedReadTransport::with_capacity(8, c);
t.chan.set_readable_bytes(
&[
0x00, 0x00, 0x00, 0x04, /* message size */
0x00, 0x01, 0x02, 0x03 /* message body */
]
);
let mut buf = vec![0; 8];
// we've read exactly 4 bytes
assert_eq!(t.read(&mut buf).unwrap(), 4);
assert_eq!(&buf[..4], &[0x00, 0x01, 0x02, 0x03]);
}
#[test]
fn must_read_message_greater_than_initial_buffer_size() {
let c = TBufferChannel::with_capacity(10, 10);
let mut t = TFramedReadTransport::with_capacity(2, c);
t.chan.set_readable_bytes(
&[
0x00, 0x00, 0x00, 0x04, /* message size */
0x00, 0x01, 0x02, 0x03 /* message body */
]
);
let mut buf = vec![0; 8];
// we've read exactly 4 bytes
assert_eq!(t.read(&mut buf).unwrap(), 4);
assert_eq!(&buf[..4], &[0x00, 0x01, 0x02, 0x03]);
}
#[test]
fn must_read_multiple_messages_in_sequence_correctly() {
let c = TBufferChannel::with_capacity(10, 10);
let mut t = TFramedReadTransport::with_capacity(2, c);
//
// 1st message
//
t.chan.set_readable_bytes(
&[
0x00, 0x00, 0x00, 0x04, /* message size */
0x00, 0x01, 0x02, 0x03 /* message body */
]
);
let mut buf = vec![0; 8];
// we've read exactly 4 bytes
assert_eq!(t.read(&mut buf).unwrap(), 4);
assert_eq!(&buf, &[0x00, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00]);
//
// 2nd message
//
t.chan.set_readable_bytes(
&[
0x00, 0x00, 0x00, 0x01, /* message size */
0x04 /* message body */
]
);
let mut buf = vec![0; 8];
// we've read exactly 1 byte
assert_eq!(t.read(&mut buf).unwrap(), 1);
assert_eq!(&buf, &[0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
}
#[test]
fn must_write_message_smaller_than_buffer_size() {
let mem = TBufferChannel::with_capacity(0, 0);
let mut t = TFramedWriteTransport::with_capacity(20, mem);
let b = vec![0; 10];
// should have written 10 bytes
assert_eq!(t.write(&b).unwrap(), 10);
}
#[test]
fn must_return_zero_if_caller_calls_write_with_empty_buffer() {
let mem = TBufferChannel::with_capacity(0, 10);
let mut t = TFramedWriteTransport::with_capacity(10, mem);
let expected: [u8; 0] = [];
assert_eq!(t.write(&[]).unwrap(), 0);
assert_eq_transport_written_bytes!(t, expected);
}
#[test]
fn must_write_to_inner_transport_on_flush() {
let mem = TBufferChannel::with_capacity(10, 10);
let mut t = TFramedWriteTransport::new(mem);
let b: [u8; 5] = [0x00, 0x01, 0x02, 0x03, 0x04];
assert_eq!(t.write(&b).unwrap(), 5);
assert_eq_transport_num_written_bytes!(t, 0);
assert!(t.flush().is_ok());
let expected_bytes = [
0x00, 0x00, 0x00, 0x05, /* message size */
0x00, 0x01, 0x02, 0x03, 0x04 /* message body */
];
assert_eq_transport_written_bytes!(t, expected_bytes);
}
#[test]
fn must_write_message_greater_than_buffer_size_00() {
let mem = TBufferChannel::with_capacity(0, 10);
// IMPORTANT: DO **NOT** CHANGE THE WRITE_CAPACITY OR THE NUMBER OF BYTES TO BE WRITTEN!
// these lengths were chosen to be just long enough
// that doubling the capacity is a **worse** choice than
// simply resizing the buffer to b.len()
let mut t = TFramedWriteTransport::with_capacity(1, mem);
let b = [0x00, 0x01, 0x02];
// should have written 3 bytes
assert_eq!(t.write(&b).unwrap(), 3);
assert_eq_transport_num_written_bytes!(t, 0);
assert!(t.flush().is_ok());
let expected_bytes = [
0x00, 0x00, 0x00, 0x03, /* message size */
0x00, 0x01, 0x02 /* message body */
];
assert_eq_transport_written_bytes!(t, expected_bytes);
}
#[test]
fn must_write_message_greater_than_buffer_size_01() {
let mem = TBufferChannel::with_capacity(0, 10);
// IMPORTANT: DO **NOT** CHANGE THE WRITE_CAPACITY OR THE NUMBER OF BYTES TO BE WRITTEN!
// these lengths were chosen to be just long enough
// that doubling the capacity is a **better** choice than
// simply resizing the buffer to b.len()
let mut t = TFramedWriteTransport::with_capacity(2, mem);
let b = [0x00, 0x01, 0x02];
// should have written 3 bytes
assert_eq!(t.write(&b).unwrap(), 3);
assert_eq_transport_num_written_bytes!(t, 0);
assert!(t.flush().is_ok());
let expected_bytes = [
0x00, 0x00, 0x00, 0x03, /* message size */
0x00, 0x01, 0x02 /* message body */
];
assert_eq_transport_written_bytes!(t, expected_bytes);
}
#[test]
fn must_return_error_if_nothing_can_be_written_to_inner_transport_on_flush() {
let mem = TBufferChannel::with_capacity(0, 0);
let mut t = TFramedWriteTransport::with_capacity(1, mem);
let b = vec![0; 10];
// should have written 10 bytes
assert_eq!(t.write(&b).unwrap(), 10);
// let's flush
let r = t.flush();
// this time we'll error out because the flush can't write to the underlying channel
assert!(r.is_err());
}
#[test]
fn must_write_successfully_after_flush() {
// IMPORTANT: write capacity *MUST* be greater
// than message sizes used in this test + 4-byte frame header
let mem = TBufferChannel::with_capacity(0, 10);
let mut t = TFramedWriteTransport::with_capacity(5, mem);
// write and flush
let first_message: [u8; 5] = [0x00, 0x01, 0x02, 0x03, 0x04];
assert_eq!(t.write(&first_message).unwrap(), 5);
assert!(t.flush().is_ok());
let mut expected = Vec::new();
expected.write_all(&[0x00, 0x00, 0x00, 0x05]).unwrap(); // message size
expected.extend_from_slice(&first_message);
// check the flushed bytes
assert_eq!(t.channel.write_bytes(), expected);
// reset our underlying transport
t.channel.empty_write_buffer();
let second_message: [u8; 3] = [0x05, 0x06, 0x07];
assert_eq!(t.write(&second_message).unwrap(), 3);
assert!(t.flush().is_ok());
expected.clear();
expected.write_all(&[0x00, 0x00, 0x00, 0x03]).unwrap(); // message size
expected.extend_from_slice(&second_message);
// check the flushed bytes
assert_eq!(t.channel.write_bytes(), expected);
}
}

View file

@ -0,0 +1,393 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::cmp;
use std::io;
use std::sync::{Arc, Mutex};
use super::{ReadHalf, TIoChannel, WriteHalf};
/// In-memory read and write channel with fixed-size read and write buffers.
///
/// On a `write` bytes are written to the internal write buffer. Writes are no
/// longer accepted once this buffer is full. Callers must `empty_write_buffer()`
/// before subsequent writes are accepted.
///
/// You can set readable bytes in the internal read buffer by filling it with
/// `set_readable_bytes(...)`. Callers can then read until the buffer is
/// depleted. No further reads are accepted until the internal read buffer is
/// replenished again.
#[derive(Debug)]
pub struct TBufferChannel {
read: Arc<Mutex<ReadData>>,
write: Arc<Mutex<WriteData>>,
}
#[derive(Debug)]
struct ReadData {
buf: Box<[u8]>,
pos: usize,
idx: usize,
cap: usize,
}
#[derive(Debug)]
struct WriteData {
buf: Box<[u8]>,
pos: usize,
cap: usize,
}
impl TBufferChannel {
/// Constructs a new, empty `TBufferChannel` with the given
/// read buffer capacity and write buffer capacity.
pub fn with_capacity(read_capacity: usize, write_capacity: usize) -> TBufferChannel {
TBufferChannel {
read: Arc::new(
Mutex::new(
ReadData {
buf: vec![0; read_capacity].into_boxed_slice(),
idx: 0,
pos: 0,
cap: read_capacity,
},
),
),
write: Arc::new(
Mutex::new(
WriteData {
buf: vec![0; write_capacity].into_boxed_slice(),
pos: 0,
cap: write_capacity,
},
),
),
}
}
/// Return a copy of the bytes held by the internal read buffer.
/// Returns an empty vector if no readable bytes are present.
pub fn read_bytes(&self) -> Vec<u8> {
let rdata = self.read.as_ref().lock().unwrap();
let mut buf = vec![0u8; rdata.idx];
buf.copy_from_slice(&rdata.buf[..rdata.idx]);
buf
}
// FIXME: do I really need this API call?
// FIXME: should this simply reset to the last set of readable bytes?
/// Reset the number of readable bytes to zero.
///
/// Subsequent calls to `read` will return nothing.
pub fn empty_read_buffer(&mut self) {
let mut rdata = self.read.as_ref().lock().unwrap();
rdata.pos = 0;
rdata.idx = 0;
}
/// Copy bytes from the source buffer `buf` into the internal read buffer,
/// overwriting any existing bytes. Returns the number of bytes copied,
/// which is `min(buf.len(), internal_read_buf.len())`.
pub fn set_readable_bytes(&mut self, buf: &[u8]) -> usize {
self.empty_read_buffer();
let mut rdata = self.read.as_ref().lock().unwrap();
let max_bytes = cmp::min(rdata.cap, buf.len());
rdata.buf[..max_bytes].clone_from_slice(&buf[..max_bytes]);
rdata.idx = max_bytes;
max_bytes
}
/// Return a copy of the bytes held by the internal write buffer.
/// Returns an empty vector if no bytes were written.
pub fn write_bytes(&self) -> Vec<u8> {
let wdata = self.write.as_ref().lock().unwrap();
let mut buf = vec![0u8; wdata.pos];
buf.copy_from_slice(&wdata.buf[..wdata.pos]);
buf
}
/// Resets the internal write buffer, making it seem like no bytes were
/// written. Calling `write_buffer` after this returns an empty vector.
pub fn empty_write_buffer(&mut self) {
let mut wdata = self.write.as_ref().lock().unwrap();
wdata.pos = 0;
}
/// Overwrites the contents of the read buffer with the contents of the
/// write buffer. The write buffer is emptied after this operation.
pub fn copy_write_buffer_to_read_buffer(&mut self) {
// FIXME: redo this entire method
let buf = {
let wdata = self.write.as_ref().lock().unwrap();
let b = &wdata.buf[..wdata.pos];
let mut b_ret = vec![0; b.len()];
b_ret.copy_from_slice(b);
b_ret
};
let bytes_copied = self.set_readable_bytes(&buf);
assert_eq!(bytes_copied, buf.len());
self.empty_write_buffer();
}
}
impl TIoChannel for TBufferChannel {
fn split(self) -> ::Result<(ReadHalf<Self>, WriteHalf<Self>)>
where
Self: Sized,
{
Ok(
(ReadHalf {
handle: TBufferChannel {
read: self.read.clone(),
write: self.write.clone(),
},
},
WriteHalf {
handle: TBufferChannel {
read: self.read.clone(),
write: self.write.clone(),
},
}),
)
}
}
impl io::Read for TBufferChannel {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut rdata = self.read.as_ref().lock().unwrap();
let nread = cmp::min(buf.len(), rdata.idx - rdata.pos);
buf[..nread].clone_from_slice(&rdata.buf[rdata.pos..rdata.pos + nread]);
rdata.pos += nread;
Ok(nread)
}
}
impl io::Write for TBufferChannel {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut wdata = self.write.as_ref().lock().unwrap();
let nwrite = cmp::min(buf.len(), wdata.cap - wdata.pos);
let (start, end) = (wdata.pos, wdata.pos + nwrite);
wdata.buf[start..end].clone_from_slice(&buf[..nwrite]);
wdata.pos += nwrite;
Ok(nwrite)
}
fn flush(&mut self) -> io::Result<()> {
Ok(()) // nothing to do on flush
}
}
#[cfg(test)]
mod tests {
use std::io::{Read, Write};
use super::TBufferChannel;
#[test]
fn must_empty_write_buffer() {
let mut t = TBufferChannel::with_capacity(0, 1);
let bytes_to_write: [u8; 1] = [0x01];
let result = t.write(&bytes_to_write);
assert_eq!(result.unwrap(), 1);
assert_eq!(&t.write_bytes(), &bytes_to_write);
t.empty_write_buffer();
assert_eq!(t.write_bytes().len(), 0);
}
#[test]
fn must_accept_writes_after_buffer_emptied() {
let mut t = TBufferChannel::with_capacity(0, 2);
let bytes_to_write: [u8; 2] = [0x01, 0x02];
// first write (all bytes written)
let result = t.write(&bytes_to_write);
assert_eq!(result.unwrap(), 2);
assert_eq!(&t.write_bytes(), &bytes_to_write);
// try write again (nothing should be written)
let result = t.write(&bytes_to_write);
assert_eq!(result.unwrap(), 0);
assert_eq!(&t.write_bytes(), &bytes_to_write); // still the same as before
// now reset the buffer
t.empty_write_buffer();
assert_eq!(t.write_bytes().len(), 0);
// now try write again - the write should succeed
let result = t.write(&bytes_to_write);
assert_eq!(result.unwrap(), 2);
assert_eq!(&t.write_bytes(), &bytes_to_write);
}
#[test]
fn must_accept_multiple_writes_until_buffer_is_full() {
let mut t = TBufferChannel::with_capacity(0, 10);
// first write (all bytes written)
let bytes_to_write_0: [u8; 2] = [0x01, 0x41];
let write_0_result = t.write(&bytes_to_write_0);
assert_eq!(write_0_result.unwrap(), 2);
assert_eq!(t.write_bytes(), &bytes_to_write_0);
// second write (all bytes written, starting at index 2)
let bytes_to_write_1: [u8; 7] = [0x24, 0x41, 0x32, 0x33, 0x11, 0x98, 0xAF];
let write_1_result = t.write(&bytes_to_write_1);
assert_eq!(write_1_result.unwrap(), 7);
assert_eq!(&t.write_bytes()[2..], &bytes_to_write_1);
// third write (only 1 byte written - that's all we have space for)
let bytes_to_write_2: [u8; 3] = [0xBF, 0xDA, 0x98];
let write_2_result = t.write(&bytes_to_write_2);
assert_eq!(write_2_result.unwrap(), 1);
assert_eq!(&t.write_bytes()[9..], &bytes_to_write_2[0..1]); // how does this syntax work?!
// fourth write (no writes are accepted)
let bytes_to_write_3: [u8; 3] = [0xBF, 0xAA, 0xFD];
let write_3_result = t.write(&bytes_to_write_3);
assert_eq!(write_3_result.unwrap(), 0);
// check the full write buffer
let mut expected: Vec<u8> = Vec::with_capacity(10);
expected.extend_from_slice(&bytes_to_write_0);
expected.extend_from_slice(&bytes_to_write_1);
expected.extend_from_slice(&bytes_to_write_2[0..1]);
assert_eq!(t.write_bytes(), &expected[..]);
}
#[test]
fn must_empty_read_buffer() {
let mut t = TBufferChannel::with_capacity(1, 0);
let bytes_to_read: [u8; 1] = [0x01];
let result = t.set_readable_bytes(&bytes_to_read);
assert_eq!(result, 1);
assert_eq!(t.read_bytes(), &bytes_to_read);
t.empty_read_buffer();
assert_eq!(t.read_bytes().len(), 0);
}
#[test]
fn must_allow_readable_bytes_to_be_set_after_read_buffer_emptied() {
let mut t = TBufferChannel::with_capacity(1, 0);
let bytes_to_read_0: [u8; 1] = [0x01];
let result = t.set_readable_bytes(&bytes_to_read_0);
assert_eq!(result, 1);
assert_eq!(t.read_bytes(), &bytes_to_read_0);
t.empty_read_buffer();
assert_eq!(t.read_bytes().len(), 0);
let bytes_to_read_1: [u8; 1] = [0x02];
let result = t.set_readable_bytes(&bytes_to_read_1);
assert_eq!(result, 1);
assert_eq!(t.read_bytes(), &bytes_to_read_1);
}
#[test]
fn must_accept_multiple_reads_until_all_bytes_read() {
let mut t = TBufferChannel::with_capacity(10, 0);
let readable_bytes: [u8; 10] = [0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0x00, 0x1A, 0x2B, 0x3C, 0x4D];
// check that we're able to set the bytes to be read
let result = t.set_readable_bytes(&readable_bytes);
assert_eq!(result, 10);
assert_eq!(t.read_bytes(), &readable_bytes);
// first read
let mut read_buf_0 = vec![0; 5];
let read_result = t.read(&mut read_buf_0);
assert_eq!(read_result.unwrap(), 5);
assert_eq!(read_buf_0.as_slice(), &(readable_bytes[0..5]));
// second read
let mut read_buf_1 = vec![0; 4];
let read_result = t.read(&mut read_buf_1);
assert_eq!(read_result.unwrap(), 4);
assert_eq!(read_buf_1.as_slice(), &(readable_bytes[5..9]));
// third read (only 1 byte remains to be read)
let mut read_buf_2 = vec![0; 3];
let read_result = t.read(&mut read_buf_2);
assert_eq!(read_result.unwrap(), 1);
read_buf_2.truncate(1); // FIXME: does the caller have to do this?
assert_eq!(read_buf_2.as_slice(), &(readable_bytes[9..]));
// fourth read (nothing should be readable)
let mut read_buf_3 = vec![0; 10];
let read_result = t.read(&mut read_buf_3);
assert_eq!(read_result.unwrap(), 0);
read_buf_3.truncate(0);
// check that all the bytes we received match the original (again!)
let mut bytes_read = Vec::with_capacity(10);
bytes_read.extend_from_slice(&read_buf_0);
bytes_read.extend_from_slice(&read_buf_1);
bytes_read.extend_from_slice(&read_buf_2);
bytes_read.extend_from_slice(&read_buf_3);
assert_eq!(&bytes_read, &readable_bytes);
}
#[test]
fn must_allow_reads_to_succeed_after_read_buffer_replenished() {
let mut t = TBufferChannel::with_capacity(3, 0);
let readable_bytes_0: [u8; 3] = [0x02, 0xAB, 0x33];
// check that we're able to set the bytes to be read
let result = t.set_readable_bytes(&readable_bytes_0);
assert_eq!(result, 3);
assert_eq!(t.read_bytes(), &readable_bytes_0);
let mut read_buf = vec![0; 4];
// drain the read buffer
let read_result = t.read(&mut read_buf);
assert_eq!(read_result.unwrap(), 3);
assert_eq!(t.read_bytes(), &read_buf[0..3]);
// check that a subsequent read fails
let read_result = t.read(&mut read_buf);
assert_eq!(read_result.unwrap(), 0);
// we don't modify the read buffer on failure
let mut expected_bytes = Vec::with_capacity(4);
expected_bytes.extend_from_slice(&readable_bytes_0);
expected_bytes.push(0x00);
assert_eq!(&read_buf, &expected_bytes);
// replenish the read buffer again
let readable_bytes_1: [u8; 2] = [0x91, 0xAA];
// check that we're able to set the bytes to be read
let result = t.set_readable_bytes(&readable_bytes_1);
assert_eq!(result, 2);
assert_eq!(t.read_bytes(), &readable_bytes_1);
// read again
let read_result = t.read(&mut read_buf);
assert_eq!(read_result.unwrap(), 2);
assert_eq!(t.read_bytes(), &read_buf[0..2]);
}
}

View file

@ -0,0 +1,280 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Types used to send and receive bytes over an I/O channel.
//!
//! The core types are the `TReadTransport`, `TWriteTransport` and the
//! `TIoChannel` traits, through which `TInputProtocol` or
//! `TOutputProtocol` can receive and send primitives over the wire. While
//! `TInputProtocol` and `TOutputProtocol` instances deal with language primitives
//! the types in this module understand only bytes.
use std::io;
use std::io::{Read, Write};
use std::ops::{Deref, DerefMut};
#[cfg(test)]
macro_rules! assert_eq_transport_num_written_bytes {
($transport:ident, $num_written_bytes:expr) => {
{
assert_eq!($transport.channel.write_bytes().len(), $num_written_bytes);
}
};
}
#[cfg(test)]
macro_rules! assert_eq_transport_written_bytes {
($transport:ident, $expected_bytes:ident) => {
{
assert_eq!($transport.channel.write_bytes(), &$expected_bytes);
}
};
}
mod buffered;
mod framed;
mod socket;
mod mem;
pub use self::buffered::{TBufferedReadTransport, TBufferedReadTransportFactory,
TBufferedWriteTransport, TBufferedWriteTransportFactory};
pub use self::framed::{TFramedReadTransport, TFramedReadTransportFactory, TFramedWriteTransport,
TFramedWriteTransportFactory};
pub use self::mem::TBufferChannel;
pub use self::socket::TTcpChannel;
/// Identifies a transport used by a `TInputProtocol` to receive bytes.
pub trait TReadTransport: Read {}
/// Helper type used by a server to create `TReadTransport` instances for
/// accepted client connections.
pub trait TReadTransportFactory {
/// Create a `TTransport` that wraps a channel over which bytes are to be read.
fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send>;
}
/// Identifies a transport used by `TOutputProtocol` to send bytes.
pub trait TWriteTransport: Write {}
/// Helper type used by a server to create `TWriteTransport` instances for
/// accepted client connections.
pub trait TWriteTransportFactory {
/// Create a `TTransport` that wraps a channel over which bytes are to be sent.
fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send>;
}
impl<T> TReadTransport for T
where
T: Read,
{
}
impl<T> TWriteTransport for T
where
T: Write,
{
}
// FIXME: implement the Debug trait for boxed transports
impl<T> TReadTransportFactory for Box<T>
where
T: TReadTransportFactory + ?Sized,
{
fn create(&self, channel: Box<Read + Send>) -> Box<TReadTransport + Send> {
(**self).create(channel)
}
}
impl<T> TWriteTransportFactory for Box<T>
where
T: TWriteTransportFactory + ?Sized,
{
fn create(&self, channel: Box<Write + Send>) -> Box<TWriteTransport + Send> {
(**self).create(channel)
}
}
/// Identifies a splittable bidirectional I/O channel used to send and receive bytes.
pub trait TIoChannel: Read + Write {
/// Split the channel into a readable half and a writable half, where the
/// readable half implements `io::Read` and the writable half implements
/// `io::Write`. Returns `None` if the channel was not initialized, or if it
/// cannot be split safely.
///
/// Returned halves may share the underlying OS channel or buffer resources.
/// Implementations **should ensure** that these two halves can be safely
/// used independently by concurrent threads.
fn split(self) -> ::Result<(::transport::ReadHalf<Self>, ::transport::WriteHalf<Self>)>
where
Self: Sized;
}
/// The readable half of an object returned from `TIoChannel::split`.
#[derive(Debug)]
pub struct ReadHalf<C>
where
C: Read,
{
handle: C,
}
/// The writable half of an object returned from `TIoChannel::split`.
#[derive(Debug)]
pub struct WriteHalf<C>
where
C: Write,
{
handle: C,
}
impl<C> Read for ReadHalf<C>
where
C: Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.handle.read(buf)
}
}
impl<C> Write for WriteHalf<C>
where
C: Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.handle.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.handle.flush()
}
}
impl<C> Deref for ReadHalf<C>
where
C: Read,
{
type Target = C;
fn deref(&self) -> &Self::Target {
&self.handle
}
}
impl<C> DerefMut for ReadHalf<C>
where
C: Read,
{
fn deref_mut(&mut self) -> &mut C {
&mut self.handle
}
}
impl<C> Deref for WriteHalf<C>
where
C: Write,
{
type Target = C;
fn deref(&self) -> &Self::Target {
&self.handle
}
}
impl<C> DerefMut for WriteHalf<C>
where
C: Write,
{
fn deref_mut(&mut self) -> &mut C {
&mut self.handle
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use super::*;
#[test]
fn must_create_usable_read_channel_from_concrete_read_type() {
let r = Cursor::new([0, 1, 2]);
let _ = TBufferedReadTransport::new(r);
}
#[test]
fn must_create_usable_read_channel_from_boxed_read() {
let r: Box<Read> = Box::new(Cursor::new([0, 1, 2]));
let _ = TBufferedReadTransport::new(r);
}
#[test]
fn must_create_usable_write_channel_from_concrete_write_type() {
let w = vec![0u8; 10];
let _ = TBufferedWriteTransport::new(w);
}
#[test]
fn must_create_usable_write_channel_from_boxed_write() {
let w: Box<Write> = Box::new(vec![0u8; 10]);
let _ = TBufferedWriteTransport::new(w);
}
#[test]
fn must_create_usable_read_transport_from_concrete_read_transport() {
let r = Cursor::new([0, 1, 2]);
let mut t = TBufferedReadTransport::new(r);
takes_read_transport(&mut t)
}
#[test]
fn must_create_usable_read_transport_from_boxed_read() {
let r = Cursor::new([0, 1, 2]);
let mut t: Box<TReadTransport> = Box::new(TBufferedReadTransport::new(r));
takes_read_transport(&mut t)
}
#[test]
fn must_create_usable_write_transport_from_concrete_write_transport() {
let w = vec![0u8; 10];
let mut t = TBufferedWriteTransport::new(w);
takes_write_transport(&mut t)
}
#[test]
fn must_create_usable_write_transport_from_boxed_write() {
let w = vec![0u8; 10];
let mut t: Box<TWriteTransport> = Box::new(TBufferedWriteTransport::new(w));
takes_write_transport(&mut t)
}
fn takes_read_transport<R>(t: &mut R)
where
R: TReadTransport,
{
t.bytes();
}
fn takes_write_transport<W>(t: &mut W)
where
W: TWriteTransport,
{
t.flush().unwrap();
}
}

View file

@ -0,0 +1,165 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::convert::From;
use std::io;
use std::io::{ErrorKind, Read, Write};
use std::net::{Shutdown, TcpStream};
use {TransportErrorKind, new_transport_error};
use super::{ReadHalf, TIoChannel, WriteHalf};
/// Bidirectional TCP/IP channel.
///
/// # Examples
///
/// Create a `TTcpChannel`.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use thrift::transport::TTcpChannel;
///
/// let mut c = TTcpChannel::new();
/// c.open("localhost:9090").unwrap();
///
/// let mut buf = vec![0u8; 4];
/// c.read(&mut buf).unwrap();
/// c.write(&vec![0, 1, 2]).unwrap();
/// ```
///
/// Create a `TTcpChannel` by wrapping an existing `TcpStream`.
///
/// ```no_run
/// use std::io::{Read, Write};
/// use std::net::TcpStream;
/// use thrift::transport::TTcpChannel;
///
/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap();
///
/// // no need to call c.open() since we've already connected above
/// let mut c = TTcpChannel::with_stream(stream);
///
/// let mut buf = vec![0u8; 4];
/// c.read(&mut buf).unwrap();
/// c.write(&vec![0, 1, 2]).unwrap();
/// ```
#[derive(Debug, Default)]
pub struct TTcpChannel {
stream: Option<TcpStream>,
}
impl TTcpChannel {
/// Create an uninitialized `TTcpChannel`.
///
/// The returned instance must be opened using `TTcpChannel::open(...)`
/// before it can be used.
pub fn new() -> TTcpChannel {
TTcpChannel { stream: None }
}
/// Create a `TTcpChannel` that wraps an existing `TcpStream`.
///
/// The passed-in stream is assumed to have been opened before being wrapped
/// by the created `TTcpChannel` instance.
pub fn with_stream(stream: TcpStream) -> TTcpChannel {
TTcpChannel { stream: Some(stream) }
}
/// Connect to `remote_address`, which should have the form `host:port`.
pub fn open(&mut self, remote_address: &str) -> ::Result<()> {
if self.stream.is_some() {
Err(
new_transport_error(
TransportErrorKind::AlreadyOpen,
"tcp connection previously opened",
),
)
} else {
match TcpStream::connect(&remote_address) {
Ok(s) => {
self.stream = Some(s);
Ok(())
}
Err(e) => Err(From::from(e)),
}
}
}
/// Shut down this channel.
///
/// Both send and receive halves are closed, and this instance can no
/// longer be used to communicate with another endpoint.
pub fn close(&mut self) -> ::Result<()> {
self.if_set(|s| s.shutdown(Shutdown::Both))
.map_err(From::from)
}
fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T>
where
F: FnMut(&mut TcpStream) -> io::Result<T>,
{
if let Some(ref mut s) = self.stream {
stream_operation(s)
} else {
Err(io::Error::new(ErrorKind::NotConnected, "tcp endpoint not connected"),)
}
}
}
impl TIoChannel for TTcpChannel {
fn split(self) -> ::Result<(ReadHalf<Self>, WriteHalf<Self>)>
where
Self: Sized,
{
let mut s = self;
s.stream
.as_mut()
.and_then(|s| s.try_clone().ok())
.map(
|cloned| {
(ReadHalf { handle: TTcpChannel { stream: s.stream.take() } },
WriteHalf { handle: TTcpChannel { stream: Some(cloned) } })
},
)
.ok_or_else(
|| {
new_transport_error(
TransportErrorKind::Unknown,
"cannot clone underlying tcp stream",
)
},
)
}
}
impl Read for TTcpChannel {
fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
self.if_set(|s| s.read(b))
}
}
impl Write for TTcpChannel {
fn write(&mut self, b: &[u8]) -> io::Result<usize> {
self.if_set(|s| s.write(b))
}
fn flush(&mut self) -> io::Result<()> {
self.if_set(|s| s.flush())
}
}