Moving from govendor to dep, updated dependencies (#48)

* Moving from govendor to dep.

* Making the pull request template more friendly.

* Fixing akward space in PR template.

* goimports run on whole project using ` goimports -w $(find . -type f -name '*.go' -not -path "./vendor/*" -not -path "./gen-go/*")`

source of command: https://gist.github.com/bgentry/fd1ffef7dbde01857f66
This commit is contained in:
Renan DelValle 2018-01-07 13:13:47 -08:00 committed by GitHub
parent 9631aa3aab
commit 8d445c1c77
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2186 changed files with 400410 additions and 352 deletions

View file

@ -0,0 +1,69 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_base64_transport).
-behaviour(thrift_transport).
%% API
-export([new/1, new_transport_factory/1]).
%% thrift_transport callbacks
-export([write/2, read/2, flush/1, close/1]).
%% State
-record(b64_transport, {wrapped}).
-type state() :: #b64_transport{}.
-include("thrift_transport_behaviour.hrl").
new(Wrapped) ->
State = #b64_transport{wrapped = Wrapped},
thrift_transport:new(?MODULE, State).
write(This = #b64_transport{wrapped = Wrapped}, Data) ->
{NewWrapped, Result} = thrift_transport:write(Wrapped, base64:encode(iolist_to_binary(Data))),
{This#b64_transport{wrapped = NewWrapped}, Result}.
%% base64 doesn't support reading quite yet since it would involve
%% nasty buffering and such
read(This = #b64_transport{}, _Data) ->
{This, {error, no_reads_allowed}}.
flush(This = #b64_transport{wrapped = Wrapped0}) ->
{Wrapped1, ok} = thrift_transport:write(Wrapped0, <<"\n">>),
{Wrapped2, ok} = thrift_transport:flush(Wrapped1),
{This#b64_transport{wrapped = Wrapped2}, ok}.
close(This0) ->
{This1 = #b64_transport{wrapped = Wrapped}, ok} = flush(This0),
{NewWrapped, ok} = thrift_transport:close(Wrapped),
{This1#b64_transport{wrapped = NewWrapped}, ok}.
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
new_transport_factory(WrapFactory) ->
F = fun() ->
{ok, Wrapped} = WrapFactory(),
new(Wrapped)
end,
{ok, F}.

View file

@ -0,0 +1,347 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_binary_protocol).
-behaviour(thrift_protocol).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-export([new/1, new/2,
read/2,
write/2,
flush_transport/1,
close_transport/1,
new_protocol_factory/2
]).
-record(binary_protocol, {transport,
strict_read=true,
strict_write=true
}).
-type state() :: #binary_protocol{}.
-include("thrift_protocol_behaviour.hrl").
-define(VERSION_MASK, 16#FFFF0000).
-define(VERSION_1, 16#80010000).
-define(TYPE_MASK, 16#000000ff).
new(Transport) ->
new(Transport, _Options = []).
new(Transport, Options) ->
State = #binary_protocol{transport = Transport},
State1 = parse_options(Options, State),
thrift_protocol:new(?MODULE, State1).
parse_options([], State) ->
State;
parse_options([{strict_read, Bool} | Rest], State) when is_boolean(Bool) ->
parse_options(Rest, State#binary_protocol{strict_read=Bool});
parse_options([{strict_write, Bool} | Rest], State) when is_boolean(Bool) ->
parse_options(Rest, State#binary_protocol{strict_write=Bool}).
flush_transport(This = #binary_protocol{transport = Transport}) ->
{NewTransport, Result} = thrift_transport:flush(Transport),
{This#binary_protocol{transport = NewTransport}, Result}.
close_transport(This = #binary_protocol{transport = Transport}) ->
{NewTransport, Result} = thrift_transport:close(Transport),
{This#binary_protocol{transport = NewTransport}, Result}.
%%%
%%% instance methods
%%%
write(This0, #protocol_message_begin{
name = Name,
type = Type,
seqid = Seqid}) ->
case This0#binary_protocol.strict_write of
true ->
{This1, ok} = write(This0, {i32, ?VERSION_1 bor Type}),
{This2, ok} = write(This1, {string, Name}),
{This3, ok} = write(This2, {i32, Seqid}),
{This3, ok};
false ->
{This1, ok} = write(This0, {string, Name}),
{This2, ok} = write(This1, {byte, Type}),
{This3, ok} = write(This2, {i32, Seqid}),
{This3, ok}
end;
write(This, message_end) -> {This, ok};
write(This0, #protocol_field_begin{
name = _Name,
type = Type,
id = Id}) ->
{This1, ok} = write(This0, {byte, Type}),
{This2, ok} = write(This1, {i16, Id}),
{This2, ok};
write(This, field_stop) ->
write(This, {byte, ?tType_STOP});
write(This, field_end) -> {This, ok};
write(This0, #protocol_map_begin{
ktype = Ktype,
vtype = Vtype,
size = Size}) ->
{This1, ok} = write(This0, {byte, Ktype}),
{This2, ok} = write(This1, {byte, Vtype}),
{This3, ok} = write(This2, {i32, Size}),
{This3, ok};
write(This, map_end) -> {This, ok};
write(This0, #protocol_list_begin{
etype = Etype,
size = Size}) ->
{This1, ok} = write(This0, {byte, Etype}),
{This2, ok} = write(This1, {i32, Size}),
{This2, ok};
write(This, list_end) -> {This, ok};
write(This0, #protocol_set_begin{
etype = Etype,
size = Size}) ->
{This1, ok} = write(This0, {byte, Etype}),
{This2, ok} = write(This1, {i32, Size}),
{This2, ok};
write(This, set_end) -> {This, ok};
write(This, #protocol_struct_begin{}) -> {This, ok};
write(This, struct_end) -> {This, ok};
write(This, {bool, true}) -> write(This, {byte, 1});
write(This, {bool, false}) -> write(This, {byte, 0});
write(This, {byte, Byte}) ->
write(This, <<Byte:8/big-signed>>);
write(This, {i16, I16}) ->
write(This, <<I16:16/big-signed>>);
write(This, {i32, I32}) ->
write(This, <<I32:32/big-signed>>);
write(This, {i64, I64}) ->
write(This, <<I64:64/big-signed>>);
write(This, {double, Double}) ->
write(This, <<Double:64/big-signed-float>>);
write(This0, {string, Str}) when is_list(Str) ->
{This1, ok} = write(This0, {i32, length(Str)}),
{This2, ok} = write(This1, list_to_binary(Str)),
{This2, ok};
write(This0, {string, Bin}) when is_binary(Bin) ->
{This1, ok} = write(This0, {i32, size(Bin)}),
{This2, ok} = write(This1, Bin),
{This2, ok};
%% Data :: iolist()
write(This = #binary_protocol{transport = Trans}, Data) ->
{NewTransport, Result} = thrift_transport:write(Trans, Data),
{This#binary_protocol{transport = NewTransport}, Result}.
%%
read(This0, message_begin) ->
{This1, Initial} = read(This0, ui32),
case Initial of
{ok, Sz} when Sz band ?VERSION_MASK =:= ?VERSION_1 ->
%% we're at version 1
{This2, {ok, Name}} = read(This1, string),
{This3, {ok, SeqId}} = read(This2, i32),
Type = Sz band ?TYPE_MASK,
{This3, #protocol_message_begin{name = binary_to_list(Name),
type = Type,
seqid = SeqId}};
{ok, Sz} when Sz < 0 ->
%% there's a version number but it's unexpected
{This1, {error, {bad_binary_protocol_version, Sz}}};
{ok, _Sz} when This1#binary_protocol.strict_read =:= true ->
%% strict_read is true and there's no version header; that's an error
{This1, {error, no_binary_protocol_version}};
{ok, Sz} when This1#binary_protocol.strict_read =:= false ->
%% strict_read is false, so just read the old way
{This2, {ok, Name}} = read_data(This1, Sz),
{This3, {ok, Type}} = read(This2, byte),
{This4, {ok, SeqId}} = read(This3, i32),
{This4, #protocol_message_begin{name = binary_to_list(Name),
type = Type,
seqid = SeqId}};
Else ->
{This1, Else}
end;
read(This, message_end) -> {This, ok};
read(This, struct_begin) -> {This, ok};
read(This, struct_end) -> {This, ok};
read(This0, field_begin) ->
{This1, Result} = read(This0, byte),
case Result of
{ok, Type = ?tType_STOP} ->
{This1, #protocol_field_begin{type = Type}};
{ok, Type} ->
{This2, {ok, Id}} = read(This1, i16),
{This2, #protocol_field_begin{type = Type,
id = Id}}
end;
read(This, field_end) -> {This, ok};
read(This0, map_begin) ->
{This1, {ok, Ktype}} = read(This0, byte),
{This2, {ok, Vtype}} = read(This1, byte),
{This3, {ok, Size}} = read(This2, i32),
{This3, #protocol_map_begin{ktype = Ktype,
vtype = Vtype,
size = Size}};
read(This, map_end) -> {This, ok};
read(This0, list_begin) ->
{This1, {ok, Etype}} = read(This0, byte),
{This2, {ok, Size}} = read(This1, i32),
{This2, #protocol_list_begin{etype = Etype,
size = Size}};
read(This, list_end) -> {This, ok};
read(This0, set_begin) ->
{This1, {ok, Etype}} = read(This0, byte),
{This2, {ok, Size}} = read(This1, i32),
{This2, #protocol_set_begin{etype = Etype,
size = Size}};
read(This, set_end) -> {This, ok};
read(This0, field_stop) ->
{This1, {ok, ?tType_STOP}} = read(This0, byte),
{This1, ok};
%%
read(This0, bool) ->
{This1, Result} = read(This0, byte),
case Result of
{ok, Byte} -> {This1, {ok, Byte /= 0}};
Else -> {This1, Else}
end;
read(This0, byte) ->
{This1, Bytes} = read_data(This0, 1),
case Bytes of
{ok, <<Val:8/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
read(This0, i16) ->
{This1, Bytes} = read_data(This0, 2),
case Bytes of
{ok, <<Val:16/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
read(This0, i32) ->
{This1, Bytes} = read_data(This0, 4),
case Bytes of
{ok, <<Val:32/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
%% unsigned ints aren't used by thrift itself, but it's used for the parsing
%% of the packet version header. Without this special function BEAM works fine
%% but hipe thinks it received a bad version header.
read(This0, ui32) ->
{This1, Bytes} = read_data(This0, 4),
case Bytes of
{ok, <<Val:32/integer-unsigned-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
read(This0, i64) ->
{This1, Bytes} = read_data(This0, 8),
case Bytes of
{ok, <<Val:64/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
read(This0, double) ->
{This1, Bytes} = read_data(This0, 8),
case Bytes of
{ok, <<Val:64/float-signed-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
% returns a binary directly, call binary_to_list if necessary
read(This0, string) ->
{This1, {ok, Sz}} = read(This0, i32),
read_data(This1, Sz).
-spec read_data(#binary_protocol{}, non_neg_integer()) ->
{#binary_protocol{}, {ok, binary()} | {error, _Reason}}.
read_data(This, 0) -> {This, {ok, <<>>}};
read_data(This = #binary_protocol{transport = Trans}, Len) when is_integer(Len) andalso Len > 0 ->
{NewTransport, Result} = thrift_transport:read(Trans, Len),
{This#binary_protocol{transport = NewTransport}, Result}.
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-record(tbp_opts, {strict_read = true,
strict_write = true}).
parse_factory_options([], Opts) ->
Opts;
parse_factory_options([{strict_read, Bool} | Rest], Opts) when is_boolean(Bool) ->
parse_factory_options(Rest, Opts#tbp_opts{strict_read=Bool});
parse_factory_options([{strict_write, Bool} | Rest], Opts) when is_boolean(Bool) ->
parse_factory_options(Rest, Opts#tbp_opts{strict_write=Bool}).
%% returns a (fun() -> thrift_protocol())
new_protocol_factory(TransportFactory, Options) ->
ParsedOpts = parse_factory_options(Options, #tbp_opts{}),
F = fun() ->
case TransportFactory() of
{ok, Transport} ->
thrift_binary_protocol:new(
Transport,
[{strict_read, ParsedOpts#tbp_opts.strict_read},
{strict_write, ParsedOpts#tbp_opts.strict_write}]);
{error, Error} ->
{error, Error}
end
end,
{ok, F}.

View file

@ -0,0 +1,98 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_buffered_transport).
-behaviour(thrift_transport).
%% constructor
-export([new/1]).
%% protocol callbacks
-export([read/2, read_exact/2, write/2, flush/1, close/1]).
%% legacy api
-export([new_transport_factory/1]).
-record(t_buffered, {
wrapped,
write_buffer
}).
-type state() :: #t_buffered{}.
-spec new(Transport::thrift_transport:t_transport()) ->
thrift_transport:t_transport().
new(Wrapped) ->
State = #t_buffered{
wrapped = Wrapped,
write_buffer = []
},
thrift_transport:new(?MODULE, State).
-include("thrift_transport_behaviour.hrl").
%% reads data through from the wrapped transport
read(State = #t_buffered{wrapped = Wrapped}, Len)
when is_integer(Len), Len >= 0 ->
{NewState, Response} = thrift_transport:read(Wrapped, Len),
{State#t_buffered{wrapped = NewState}, Response}.
%% reads data through from the wrapped transport
read_exact(State = #t_buffered{wrapped = Wrapped}, Len)
when is_integer(Len), Len >= 0 ->
{NewState, Response} = thrift_transport:read_exact(Wrapped, Len),
{State#t_buffered{wrapped = NewState}, Response}.
write(State = #t_buffered{write_buffer = Buffer}, Data) ->
{State#t_buffered{write_buffer = [Buffer, Data]}, ok}.
flush(State = #t_buffered{wrapped = Wrapped, write_buffer = Buffer}) ->
case iolist_size(Buffer) of
%% if write buffer is empty, do nothing
0 -> {State, ok};
_ ->
{Written, Response} = thrift_transport:write(Wrapped, Buffer),
{Flushed, ok} = thrift_transport:flush(Written),
{State#t_buffered{wrapped = Flushed, write_buffer = []}, Response}
end.
close(State = #t_buffered{wrapped = Wrapped}) ->
{Closed, Result} = thrift_transport:close(Wrapped),
{State#t_buffered{wrapped = Closed}, Result}.
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
new_transport_factory(WrapFactory) ->
F = fun() ->
{ok, Wrapped} = WrapFactory(),
new(Wrapped)
end,
{ok, F}.

View file

@ -0,0 +1,162 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_client).
%% API
-export([new/2, call/3, send_call/3, close/1]).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-record(tclient, {service, protocol, seqid}).
new(Protocol, Service)
when is_atom(Service) ->
{ok, #tclient{protocol = Protocol,
service = Service,
seqid = 0}}.
-spec call(#tclient{}, atom(), list()) -> {#tclient{}, {ok, any()} | {error, any()}}.
call(Client = #tclient{}, Function, Args)
when is_atom(Function), is_list(Args) ->
case send_function_call(Client, Function, Args) of
{ok, Client1} -> receive_function_result(Client1, Function);
{{error, X}, Client1} -> {Client1, {error, X}};
Else -> Else
end.
%% Sends a function call but does not read the result. This is useful
%% if you're trying to log non-oneway function calls to write-only
%% transports like thrift_disk_log_transport.
-spec send_call(#tclient{}, atom(), list()) -> {#tclient{}, ok}.
send_call(Client = #tclient{}, Function, Args)
when is_atom(Function), is_list(Args) ->
case send_function_call(Client, Function, Args) of
{ok, Client1} -> {Client1, ok};
Else -> Else
end.
-spec close(#tclient{}) -> ok.
close(#tclient{protocol=Protocol}) ->
thrift_protocol:close_transport(Protocol).
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
-spec send_function_call(#tclient{}, atom(), list()) -> {ok | {error, any()}, #tclient{}}.
send_function_call(Client = #tclient{service = Service}, Function, Args) ->
{Params, Reply} = try
{Service:function_info(Function, params_type), Service:function_info(Function, reply_type)}
catch error:function_clause -> {no_function, 0}
end,
MsgType = case Reply of
oneway_void -> ?tMessageType_ONEWAY;
_ -> ?tMessageType_CALL
end,
case Params of
no_function ->
{{error, {no_function, Function}}, Client};
{struct, PList} when length(PList) =/= length(Args) ->
{{error, {bad_args, Function, Args}}, Client};
{struct, _PList} -> write_message(Client, Function, Args, Params, MsgType)
end.
-spec write_message(#tclient{}, atom(), list(), {struct, list()}, integer()) ->
{ok | {error, any()}, #tclient{}}.
write_message(Client = #tclient{protocol = P0, seqid = Seq}, Function, Args, Params, MsgType) ->
try
{P1, ok} = thrift_protocol:write(P0, #protocol_message_begin{
name = atom_to_list(Function),
type = MsgType,
seqid = Seq
}),
{P2, ok} = thrift_protocol:write(P1, {Params, list_to_tuple([Function|Args])}),
{P3, ok} = thrift_protocol:write(P2, message_end),
{P4, ok} = thrift_protocol:flush_transport(P3),
{ok, Client#tclient{protocol = P4}}
catch
error:{badmatch, {_, {error, _} = Error}} -> {Error, Client}
end.
-spec receive_function_result(#tclient{}, atom()) -> {#tclient{}, {ok, any()} | {error, any()}}.
receive_function_result(Client = #tclient{service = Service}, Function) ->
ResultType = Service:function_info(Function, reply_type),
read_result(Client, Function, ResultType).
read_result(Client, _Function, oneway_void) ->
{Client, {ok, ok}};
read_result(Client = #tclient{protocol = Proto0,
seqid = SeqId},
Function,
ReplyType) ->
case thrift_protocol:read(Proto0, message_begin) of
{Proto1, {error, Reason}} ->
NewClient = Client#tclient{protocol = Proto1},
{NewClient, {error, Reason}};
{Proto1, MessageBegin} ->
NewClient = Client#tclient{protocol = Proto1},
case MessageBegin of
#protocol_message_begin{seqid = RetSeqId} when RetSeqId =/= SeqId ->
{NewClient, {error, {bad_seq_id, SeqId}}};
#protocol_message_begin{type = ?tMessageType_EXCEPTION} ->
handle_application_exception(NewClient);
#protocol_message_begin{type = ?tMessageType_REPLY} ->
handle_reply(NewClient, Function, ReplyType)
end
end.
handle_reply(Client = #tclient{protocol = Proto0,
service = Service},
Function,
ReplyType) ->
{struct, ExceptionFields} = Service:function_info(Function, exceptions),
ReplyStructDef = {struct, [{0, ReplyType}] ++ ExceptionFields},
{Proto1, {ok, Reply}} = thrift_protocol:read(Proto0, ReplyStructDef),
{Proto2, ok} = thrift_protocol:read(Proto1, message_end),
NewClient = Client#tclient{protocol = Proto2},
ReplyList = tuple_to_list(Reply),
true = length(ReplyList) == length(ExceptionFields) + 1,
ExceptionVals = tl(ReplyList),
Thrown = [X || X <- ExceptionVals,
X =/= undefined],
case Thrown of
[] when ReplyType == {struct, []} ->
{NewClient, {ok, ok}};
[] ->
{NewClient, {ok, hd(ReplyList)}};
[Exception] ->
throw({NewClient, {exception, Exception}})
end.
handle_application_exception(Client = #tclient{protocol = Proto0}) ->
{Proto1, {ok, Exception}} =
thrift_protocol:read(Proto0, ?TApplicationException_Structure),
{Proto2, ok} = thrift_protocol:read(Proto1, message_end),
XRecord = list_to_tuple(
['TApplicationException' | tuple_to_list(Exception)]),
error_logger:error_msg("X: ~p~n", [XRecord]),
true = is_record(XRecord, 'TApplicationException'),
NewClient = Client#tclient{protocol = Proto2},
throw({NewClient, {exception, XRecord}}).

View file

@ -0,0 +1,112 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_client_util).
-export([new/4]).
-export([new_multiplexed/3, new_multiplexed/4]).
-type service_name() :: nonempty_string().
-type service_module() :: atom().
-type multiplexed_service_map() :: [{ServiceName::service_name(), ServiceModule::service_module()}].
%%
%% Splits client options into client, protocol, and transport options
%%
%% split_options([Options...]) -> {ProtocolOptions, TransportOptions}
%%
split_options(Options) ->
split_options(Options, [], []).
split_options([], ProtoIn, TransIn) ->
{ProtoIn, TransIn};
split_options([Opt = {OptKey, _} | Rest], ProtoIn, TransIn)
when OptKey =:= strict_read;
OptKey =:= strict_write;
OptKey =:= protocol ->
split_options(Rest, [Opt | ProtoIn], TransIn);
split_options([Opt = {OptKey, _} | Rest], ProtoIn, TransIn)
when OptKey =:= framed;
OptKey =:= connect_timeout;
OptKey =:= recv_timeout;
OptKey =:= sockopts;
OptKey =:= ssltransport;
OptKey =:= ssloptions->
split_options(Rest, ProtoIn, [Opt | TransIn]).
%% Client constructor for the common-case of socket transports
new(Host, Port, Service, Options)
when is_integer(Port), is_atom(Service), is_list(Options) ->
{ProtoOpts, TransOpts0} = split_options(Options),
{TransportModule, TransOpts2} = case lists:keytake(ssltransport, 1, TransOpts0) of
{value, {_, true}, TransOpts1} -> {thrift_sslsocket_transport, TransOpts1};
false -> {thrift_socket_transport, TransOpts0}
end,
{ProtocolModule, ProtoOpts1} = case lists:keytake(protocol, 1, ProtoOpts) of
{value, {_, compact}, Opts} -> {thrift_compact_protocol, Opts};
{value, {_, json}, Opts} -> {thrift_json_protocol, Opts};
{value, {_, binary}, Opts} -> {thrift_binary_protocol, Opts};
false -> {thrift_binary_protocol, ProtoOpts}
end,
{ok, TransportFactory} =
TransportModule:new_transport_factory(Host, Port, TransOpts2),
{ok, ProtocolFactory} = ProtocolModule:new_protocol_factory(
TransportFactory, ProtoOpts1),
case ProtocolFactory() of
{ok, Protocol} ->
thrift_client:new(Protocol, Service);
{error, Error} ->
{error, Error}
end.
-spec new_multiplexed(Host, Port, Services, Options) -> {ok, ServiceThriftClientList} when
Host :: nonempty_string(),
Port :: non_neg_integer(),
Services :: multiplexed_service_map(),
Options :: list(),
ServiceThriftClientList :: [{ServiceName::list(), ThriftClient::term()}].
new_multiplexed(Host, Port, Services, Options) when is_integer(Port),
is_list(Services),
is_list(Options) ->
new_multiplexed(thrift_socket_transport:new_transport_factory(Host, Port, Options), Services, Options).
-spec new_multiplexed(TransportFactoryTuple, Services, Options) -> {ok, ServiceThriftClientList} when
TransportFactoryTuple :: {ok, TransportFactory::term()},
Services :: multiplexed_service_map(),
Options :: list(),
ServiceThriftClientList :: [{ServiceName::service_name(), ThriftClient::term()}].
new_multiplexed(TransportFactoryTuple, Services, Options) when is_list(Services),
is_list(Options),
is_tuple(TransportFactoryTuple) ->
{ProtoOpts, _} = split_options(Options),
{ok, TransportFactory} = TransportFactoryTuple,
{ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory(TransportFactory, ProtoOpts),
{ok, Protocol} = ProtocolFactory(),
{ok, [{ServiceName, element(2, thrift_client:new(element(2, thrift_multiplexed_protocol:new(Protocol, ServiceName)), Service))} || {ServiceName, Service} <- Services]}.

View file

@ -0,0 +1,390 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_compact_protocol).
-behaviour(thrift_protocol).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-export([new/1, new/2,
read/2,
write/2,
flush_transport/1,
close_transport/1,
new_protocol_factory/2
]).
-define(ID_NONE, 16#10000).
-define(CBOOL_NONE, 0).
-define(CBOOL_TRUE, 1).
-define(CBOOL_FALSE, 2).
-record(t_compact, {transport,
% state for pending boolean fields
read_stack=[],
read_value=?CBOOL_NONE,
write_stack=[],
write_id=?ID_NONE
}).
-type state() :: #t_compact{}.
-include("thrift_protocol_behaviour.hrl").
-define(PROTOCOL_ID, 16#82).
-define(VERSION_MASK, 16#1f).
-define(VERSION_1, 16#01).
-define(TYPE_MASK, 16#E0).
-define(TYPE_BITS, 16#07).
-define(TYPE_SHIFT_AMOUNT, 5).
typeid_to_compact(?tType_STOP) -> 16#0;
typeid_to_compact(?tType_BOOL) -> 16#2;
typeid_to_compact(?tType_I8) -> 16#3;
typeid_to_compact(?tType_I16) -> 16#4;
typeid_to_compact(?tType_I32) -> 16#5;
typeid_to_compact(?tType_I64) -> 16#6;
typeid_to_compact(?tType_DOUBLE) -> 16#7;
typeid_to_compact(?tType_STRING) -> 16#8;
typeid_to_compact(?tType_STRUCT) -> 16#C;
typeid_to_compact(?tType_MAP) -> 16#B;
typeid_to_compact(?tType_SET) -> 16#A;
typeid_to_compact(?tType_LIST) -> 16#9.
compact_to_typeid(16#0) -> ?tType_STOP;
compact_to_typeid(?CBOOL_FALSE) -> ?tType_BOOL;
compact_to_typeid(?CBOOL_TRUE) -> ?tType_BOOL;
compact_to_typeid(16#7) -> ?tType_DOUBLE;
compact_to_typeid(16#3) -> ?tType_I8;
compact_to_typeid(16#4) -> ?tType_I16;
compact_to_typeid(16#5) -> ?tType_I32;
compact_to_typeid(16#6) -> ?tType_I64;
compact_to_typeid(16#8) -> ?tType_STRING;
compact_to_typeid(16#C) -> ?tType_STRUCT;
compact_to_typeid(16#B) -> ?tType_MAP;
compact_to_typeid(16#A) -> ?tType_SET;
compact_to_typeid(16#9) -> ?tType_LIST.
bool_to_cbool(Value) when Value -> ?CBOOL_TRUE;
bool_to_cbool(_) -> ?CBOOL_FALSE.
cbool_to_bool(Value) -> Value =:= ?CBOOL_TRUE.
new(Transport) -> new(Transport, _Options = []).
new(Transport, _Options) ->
State = #t_compact{transport = Transport},
thrift_protocol:new(?MODULE, State).
flush_transport(This = #t_compact{transport = Transport}) ->
{NewTransport, Result} = thrift_transport:flush(Transport),
{This#t_compact{transport = NewTransport}, Result}.
close_transport(This = #t_compact{transport = Transport}) ->
{NewTransport, Result} = thrift_transport:close(Transport),
{This#t_compact{transport = NewTransport}, Result}.
%%%
%%% instance methods
%%%
write_field_begin(This0 = #t_compact{write_stack=[LastId|T]}, CompactType, Id) ->
IdDiff = Id - LastId,
This1 = This0#t_compact{write_stack=[Id|T]},
case (IdDiff > 0) and (IdDiff < 16) of
true -> write(This1, {byte, (IdDiff bsl 4) bor CompactType});
false ->
{This2, ok} = write(This1, {byte, CompactType}),
write(This2, {i16, Id})
end.
-spec to_zigzag(integer()) -> non_neg_integer().
to_zigzag(Value) -> 16#FFFFFFFFFFFFFFFF band ((Value bsl 1) bxor (Value bsr 63)).
-spec from_zigzag(non_neg_integer()) -> integer().
from_zigzag(Value) -> (Value bsr 1) bxor -(Value band 1).
-spec to_varint(non_neg_integer(), iolist()) -> iolist().
to_varint(Value, Acc) when (Value < 16#80) -> [Acc, Value];
to_varint(Value, Acc) ->
to_varint(Value bsr 7, [Acc, ((Value band 16#7F) bor 16#80)]).
-spec read_varint(#t_compact{}, non_neg_integer(), non_neg_integer()) -> non_neg_integer().
read_varint(This0, Acc, Count) ->
{This1, {ok, Byte}} = read(This0, byte),
case (Byte band 16#80) of
0 -> {This1, {ok, (Byte bsl (7 * Count)) + Acc}};
_ -> read_varint(This1, ((Byte band 16#7f) bsl (7 * Count)) + Acc, Count + 1)
end.
write(This0, #protocol_message_begin{
name = Name,
type = Type,
seqid = Seqid}) ->
{This1, ok} = write(This0, {byte, ?PROTOCOL_ID}),
{This2, ok} = write(This1, {byte, (?VERSION_1 band ?VERSION_MASK) bor (Type bsl ?TYPE_SHIFT_AMOUNT)}),
{This3, ok} = write(This2, {ui32, Seqid}),
{This4, ok} = write(This3, {string, Name}),
{This4, ok};
write(This, message_end) -> {This, ok};
write(This0, #protocol_field_begin{
name = _Name,
type = Type,
id = Id})
when (Type =:= ?tType_BOOL) -> {This0#t_compact{write_id = Id}, ok};
write(This0, #protocol_field_begin{
name = _Name,
type = Type,
id = Id}) ->
write_field_begin(This0, typeid_to_compact(Type), Id);
write(This, field_stop) -> write(This, {byte, ?tType_STOP});
write(This, field_end) -> {This, ok};
write(This0, #protocol_map_begin{
ktype = _Ktype,
vtype = _Vtype,
size = Size})
when Size =:= 0 ->
write(This0, {byte, 0});
write(This0, #protocol_map_begin{
ktype = Ktype,
vtype = Vtype,
size = Size}) ->
{This1, ok} = write(This0, {ui32, Size}),
write(This1, {byte, (typeid_to_compact(Ktype) bsl 4) bor typeid_to_compact(Vtype)});
write(This, map_end) -> {This, ok};
write(This0, #protocol_list_begin{
etype = Etype,
size = Size})
when Size < 16#f ->
write(This0, {byte, (Size bsl 4) bor typeid_to_compact(Etype)});
write(This0, #protocol_list_begin{
etype = Etype,
size = Size}) ->
{This1, ok} = write(This0, {byte, 16#f0 bor typeid_to_compact(Etype)}),
write(This1, {ui32, Size});
write(This, list_end) -> {This, ok};
write(This0, #protocol_set_begin{
etype = Etype,
size = Size}) ->
write(This0, #protocol_list_begin{etype = Etype, size = Size});
write(This, set_end) -> {This, ok};
write(This = #t_compact{write_stack = Stack}, #protocol_struct_begin{}) ->
{This#t_compact{write_stack = [0|Stack]}, ok};
write(This = #t_compact{write_stack = [_|T]}, struct_end) ->
{This#t_compact{write_stack = T}, ok};
write(This = #t_compact{write_id = ?ID_NONE}, {bool, Value}) ->
write(This, {byte, bool_to_cbool(Value)});
write(This0 = #t_compact{write_id = Id}, {bool, Value}) ->
{This1, ok} = write_field_begin(This0, bool_to_cbool(Value), Id),
{This1#t_compact{write_id = ?ID_NONE}, ok};
write(This, {byte, Value}) when is_integer(Value) ->
write(This, <<Value:8/big-signed>>);
write(This, {i16, Value}) when is_integer(Value) -> write(This, to_varint(to_zigzag(Value), []));
write(This, {ui32, Value}) when is_integer(Value) -> write(This, to_varint(Value, []));
write(This, {i32, Value}) when is_integer(Value) ->
write(This, to_varint(to_zigzag(Value), []));
write(This, {i64, Value}) when is_integer(Value) -> write(This, to_varint(to_zigzag(Value), []));
write(This, {double, Double}) ->
write(This, <<Double:64/float-signed-little>>);
write(This0, {string, Str}) when is_list(Str) ->
% TODO: limit length
{This1, ok} = write(This0, {ui32, length(Str)}),
{This2, ok} = write(This1, list_to_binary(Str)),
{This2, ok};
write(This0, {string, Bin}) when is_binary(Bin) ->
% TODO: limit length
{This1, ok} = write(This0, {ui32, size(Bin)}),
{This2, ok} = write(This1, Bin),
{This2, ok};
%% Data :: iolist()
write(This = #t_compact{transport = Trans}, Data) ->
{NewTransport, Result} = thrift_transport:write(Trans, Data),
{This#t_compact{transport = NewTransport}, Result}.
%%
%%
read(This0, message_begin) ->
{This1, {ok, ?PROTOCOL_ID}} = read(This0, ubyte),
{This2, {ok, VerAndType}} = read(This1, ubyte),
?VERSION_1 = VerAndType band ?VERSION_MASK,
{This3, {ok, SeqId}} = read(This2, ui32),
{This4, {ok, Name}} = read(This3, string),
{This4, #protocol_message_begin{
name = binary_to_list(Name),
type = (VerAndType bsr ?TYPE_SHIFT_AMOUNT) band ?TYPE_BITS,
seqid = SeqId}};
read(This, message_end) -> {This, ok};
read(This = #t_compact{read_stack = Stack}, struct_begin) ->
{This#t_compact{read_stack = [0|Stack]}, ok};
read(This = #t_compact{read_stack = [_H|T]}, struct_end) ->
{This#t_compact{read_stack = T}, ok};
read(This0 = #t_compact{read_stack = [LastId|T]}, field_begin) ->
{This1, {ok, Byte}} = read(This0, ubyte),
case Byte band 16#f of
CompactType = ?tType_STOP ->
{This1, #protocol_field_begin{type = CompactType}};
CompactType ->
{This2, {ok, Id}} = case Byte bsr 4 of
0 -> read(This1, i16);
IdDiff ->
{This1, {ok, LastId + IdDiff}}
end,
case compact_to_typeid(CompactType) of
?tType_BOOL ->
{This2#t_compact{read_stack = [Id|T], read_value = cbool_to_bool(CompactType)},
#protocol_field_begin{type = ?tType_BOOL, id = Id}};
Type ->
{This2#t_compact{read_stack = [Id|T]},
#protocol_field_begin{type = Type, id = Id}}
end
end;
read(This, field_end) -> {This, ok};
read(This0, map_begin) ->
{This1, {ok, Size}} = read(This0, ui32),
{This2, {ok, KV}} = case Size of
0 -> {This1, {ok, 0}};
_ -> read(This1, ubyte)
end,
{This2, #protocol_map_begin{ktype = compact_to_typeid(KV bsr 4),
vtype = compact_to_typeid(KV band 16#f),
size = Size}};
read(This, map_end) -> {This, ok};
read(This0, list_begin) ->
{This1, {ok, SizeAndType}} = read(This0, ubyte),
{This2, {ok, Size}} = case (SizeAndType bsr 4) band 16#f of
16#f -> read(This1, ui32);
Else -> {This1, {ok, Else}}
end,
{This2, #protocol_list_begin{etype = compact_to_typeid(SizeAndType band 16#f),
size = Size}};
read(This, list_end) -> {This, ok};
read(This0, set_begin) ->
{This1, {ok, SizeAndType}} = read(This0, ubyte),
{This2, {ok, Size}} = case (SizeAndType bsr 4) band 16#f of
16#f -> read(This1, ui32);
Else -> {This1, {ok, Else}}
end,
{This2, #protocol_set_begin{etype = compact_to_typeid(SizeAndType band 16#f),
size = Size}};
read(This, set_end) -> {This, ok};
read(This0, field_stop) ->
{This1, {ok, ?tType_STOP}} = read(This0, ubyte),
{This1, ok};
%%
read(This0 = #t_compact{read_value = ?CBOOL_NONE}, bool) ->
{This1, {ok, Byte}} = read(This0, ubyte),
{This1, {ok, cbool_to_bool(Byte)}};
read(This0 = #t_compact{read_value = Bool}, bool) ->
{This0#t_compact{read_value = ?CBOOL_NONE}, {ok, Bool}};
read(This0, ubyte) ->
{This1, {ok, <<Val:8/integer-unsigned-big, _/binary>>}} = read_data(This0, 1),
{This1, {ok, Val}};
read(This0, byte) ->
{This1, Bytes} = read_data(This0, 1),
case Bytes of
{ok, <<Val:8/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
read(This0, i16) ->
{This1, {ok, Zigzag}} = read_varint(This0, 0, 0),
{This1, {ok, from_zigzag(Zigzag)}};
read(This0, ui32) -> read_varint(This0, 0, 0);
read(This0, i32) ->
{This1, {ok, Zigzag}} = read_varint(This0, 0, 0),
{This1, {ok, from_zigzag(Zigzag)}};
read(This0, i64) ->
{This1, {ok, Zigzag}} = read_varint(This0, 0, 0),
{This1, {ok, from_zigzag(Zigzag)}};
read(This0, double) ->
{This1, Bytes} = read_data(This0, 8),
case Bytes of
{ok, <<Val:64/float-signed-little, _/binary>>} -> {This1, {ok, Val}};
Else -> {This1, Else}
end;
% returns a binary directly, call binary_to_list if necessary
read(This0, string) ->
{This1, {ok, Sz}} = read(This0, ui32),
read_data(This1, Sz).
-spec read_data(#t_compact{}, non_neg_integer()) ->
{#t_compact{}, {ok, binary()} | {error, _Reason}}.
read_data(This, 0) -> {This, {ok, <<>>}};
read_data(This = #t_compact{transport = Trans}, Len) when is_integer(Len) andalso Len > 0 ->
{NewTransport, Result} = thrift_transport:read(Trans, Len),
{This#t_compact{transport = NewTransport}, Result}.
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% returns a (fun() -> thrift_protocol())
new_protocol_factory(TransportFactory, _Options) ->
F = fun() ->
case TransportFactory() of
{ok, Transport} ->
thrift_compact_protocol:new(
Transport,
[]);
{error, Error} ->
{error, Error}
end
end,
{ok, F}.

View file

@ -0,0 +1,123 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
%%% Todo: this might be better off as a gen_server type of transport
%%% that handles stuff like group commit, similar to TFileTransport
%%% in cpp land
-module(thrift_disk_log_transport).
-behaviour(thrift_transport).
%% API
-export([new/2, new_transport_factory/2, new_transport_factory/3]).
%% thrift_transport callbacks
-export([read/2, write/2, force_flush/1, flush/1, close/1]).
%% state
-record(dl_transport, {log,
close_on_close = false,
sync_every = infinity,
sync_tref}).
-type state() :: #dl_transport{}.
-include("thrift_transport_behaviour.hrl").
%% Create a transport attached to an already open log.
%% If you'd like this transport to close the disk_log using disk_log:lclose()
%% when the transport is closed, pass a {close_on_close, true} tuple in the
%% Opts list.
new(LogName, Opts) when is_atom(LogName), is_list(Opts) ->
State = parse_opts(Opts, #dl_transport{log = LogName}),
State2 =
case State#dl_transport.sync_every of
N when is_integer(N), N > 0 ->
{ok, TRef} = timer:apply_interval(N, ?MODULE, force_flush, [State]),
State#dl_transport{sync_tref = TRef};
_ -> State
end,
thrift_transport:new(?MODULE, State2).
parse_opts([], State) ->
State;
parse_opts([{close_on_close, Bool} | Rest], State) when is_boolean(Bool) ->
parse_opts(Rest, State#dl_transport{close_on_close = Bool});
parse_opts([{sync_every, Int} | Rest], State) when is_integer(Int), Int > 0 ->
parse_opts(Rest, State#dl_transport{sync_every = Int}).
%%%% TRANSPORT IMPLENTATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% disk_log_transport is write-only
read(State, _Len) ->
{State, {error, no_read_from_disk_log}}.
write(This = #dl_transport{log = Log}, Data) ->
{This, disk_log:balog(Log, erlang:iolist_to_binary(Data))}.
force_flush(#dl_transport{log = Log}) ->
error_logger:info_msg("~p syncing~n", [?MODULE]),
disk_log:sync(Log).
flush(This = #dl_transport{log = Log, sync_every = SE}) ->
case SE of
undefined -> % no time-based sync
disk_log:sync(Log);
_Else -> % sync will happen automagically
ok
end,
{This, ok}.
%% On close, close the underlying log if we're configured to do so.
close(This = #dl_transport{close_on_close = false}) ->
{This, ok};
close(This = #dl_transport{log = Log}) ->
{This, disk_log:lclose(Log)}.
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
new_transport_factory(Name, ExtraLogOpts) ->
new_transport_factory(Name, ExtraLogOpts, [{close_on_close, true},
{sync_every, 500}]).
new_transport_factory(Name, ExtraLogOpts, TransportOpts) ->
F = fun() -> factory_impl(Name, ExtraLogOpts, TransportOpts) end,
{ok, F}.
factory_impl(Name, ExtraLogOpts, TransportOpts) ->
LogOpts = [{name, Name},
{format, external},
{type, wrap} |
ExtraLogOpts],
Log =
case disk_log:open(LogOpts) of
{ok, LogS} ->
LogS;
{repaired, LogS, Info1, Info2} ->
error_logger:info_msg("Disk log ~p repaired: ~p, ~p~n", [LogS, Info1, Info2]),
LogS
end,
new(Log, TransportOpts).

View file

@ -0,0 +1,115 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_file_transport).
-behaviour(thrift_transport).
%% constructors
-export([new/1, new/2]).
%% protocol callbacks
-export([read/2, read_exact/2, write/2, flush/1, close/1]).
%% legacy api
-export([new_reader/1]).
-record(t_file, {
device,
should_close = true,
mode = write
}).
-type state() :: #t_file{}.
-spec new(Device::file:io_device()) ->
thrift_transport:t_transport().
new(Device) -> new(Device, []).
-spec new(Device::file:io_device(), Opts::list()) ->
thrift_transport:t_transport().
%% Device should be opened in raw and binary mode.
new(Device, Opts) when is_list(Opts) ->
State = parse_opts(Opts, #t_file{device = Device}),
thrift_transport:new(?MODULE, State).
parse_opts([{should_close, Bool}|Rest], State)
when is_boolean(Bool) ->
parse_opts(Rest, State#t_file{should_close = Bool});
parse_opts([{mode, Mode}|Rest], State)
when Mode =:= write; Mode =:= read ->
parse_opts(Rest, State#t_file{mode = Mode});
parse_opts([], State) ->
State.
-include("thrift_transport_behaviour.hrl").
read(State = #t_file{device = Device, mode = read}, Len)
when is_integer(Len), Len >= 0 ->
case file:read(Device, Len) of
eof -> {State, {error, eof}};
{ok, Result} -> {State, {ok, iolist_to_binary(Result)}}
end;
read(State, _) ->
{State, {error, write_mode}}.
read_exact(State = #t_file{device = Device, mode = read}, Len)
when is_integer(Len), Len >= 0 ->
case file:read(Device, Len) of
eof -> {State, {error, eof}};
{ok, Result} ->
case iolist_size(Result) of
X when X < Len -> {State, {error, eof}};
_ -> {State, {ok, iolist_to_binary(Result)}}
end
end;
read_exact(State, _) ->
{State, {error, write_mode}}.
write(State = #t_file{device = Device, mode = write}, Data) ->
{State, file:write(Device, Data)};
write(State, _) ->
{State, {error, read_mode}}.
flush(State = #t_file{device = Device, mode = write}) ->
{State, file:sync(Device)}.
close(State = #t_file{device = Device, should_close = SC}) ->
case SC of
true -> {State, file:close(Device)};
false -> {State, ok}
end.
%% legacy api. left for compatibility
new_reader(Filename) ->
case file:open(Filename, [read, binary, {read_ahead, 1024*1024}]) of
{ok, IODevice} -> new(IODevice, [{should_close, true}, {mode, read}]);
Error -> Error
end.

View file

@ -0,0 +1,125 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_framed_transport).
-behaviour(thrift_transport).
%% constructor
-export([new/1]).
%% protocol callbacks
-export([read/2, read_exact/2, write/2, flush/1, close/1]).
-record(t_framed, {
wrapped,
read_buffer,
write_buffer
}).
-type state() :: #t_framed{}.
-spec new(Transport::thrift_transport:t_transport()) ->
thrift_transport:t_transport().
new(Wrapped) ->
State = #t_framed{
wrapped = Wrapped,
read_buffer = [],
write_buffer = []
},
thrift_transport:new(?MODULE, State).
-include("thrift_transport_behaviour.hrl").
read(State = #t_framed{wrapped = Wrapped, read_buffer = Buffer}, Len)
when is_integer(Len), Len >= 0 ->
Binary = iolist_to_binary(Buffer),
case Binary of
<<>> when Len > 0 ->
case next_frame(Wrapped) of
{NewState, {ok, Frame}} ->
NewBinary = iolist_to_binary([Binary, Frame]),
Give = min(iolist_size(NewBinary), Len),
{Result, Remaining} = split_binary(NewBinary, Give),
{State#t_framed{wrapped = NewState, read_buffer = Remaining}, {ok, Result}};
Error -> Error
end;
%% read of zero bytes
<<>> -> {State, {ok, <<>>}};
%% read buffer is nonempty
_ ->
Give = min(iolist_size(Binary), Len),
{Result, Remaining} = split_binary(Binary, Give),
{State#t_framed{read_buffer = Remaining}, {ok, Result}}
end.
read_exact(State = #t_framed{wrapped = Wrapped, read_buffer = Buffer}, Len)
when is_integer(Len), Len >= 0 ->
Binary = iolist_to_binary(Buffer),
case iolist_size(Binary) of
%% read buffer is larger than requested read size
X when X >= Len ->
{Result, Remaining} = split_binary(Binary, Len),
{State#t_framed{read_buffer = Remaining}, {ok, Result}};
%% read buffer is insufficient for requested read size
_ ->
case next_frame(Wrapped) of
{NewState, {ok, Frame}} ->
read_exact(
State#t_framed{wrapped = NewState, read_buffer = [Buffer, Frame]},
Len
);
{NewState, Error} ->
{State#t_framed{wrapped = NewState}, Error}
end
end.
next_frame(Transport) ->
case thrift_transport:read_exact(Transport, 4) of
{NewState, {ok, <<FrameLength:32/integer-signed-big>>}} ->
thrift_transport:read_exact(NewState, FrameLength);
Error -> Error
end.
write(State = #t_framed{write_buffer = Buffer}, Data) ->
{State#t_framed{write_buffer = [Buffer, Data]}, ok}.
flush(State = #t_framed{write_buffer = Buffer, wrapped = Wrapped}) ->
case iolist_size(Buffer) of
%% if write buffer is empty, do nothing
0 -> {State, ok};
FrameLen ->
Data = [<<FrameLen:32/integer-signed-big>>, Buffer],
{Written, Response} = thrift_transport:write(Wrapped, Data),
{Flushed, ok} = thrift_transport:flush(Written),
{State#t_framed{wrapped = Flushed, write_buffer = []}, Response}
end.
close(State = #t_framed{wrapped = Wrapped}) ->
{Closed, Result} = thrift_transport:close(Wrapped),
{State#t_framed{wrapped = Closed}, Result}.

View file

@ -0,0 +1,116 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_http_transport).
-behaviour(thrift_transport).
%% API
-export([new/2, new/3]).
%% thrift_transport callbacks
-export([write/2, read/2, flush/1, close/1]).
-record(http_transport, {host, % string()
path, % string()
read_buffer, % iolist()
write_buffer, % iolist()
http_options, % see http(3)
extra_headers % [{str(), str()}, ...]
}).
-type state() :: #http_transport{}.
-include("thrift_transport_behaviour.hrl").
new(Host, Path) ->
new(Host, Path, _Options = []).
%%--------------------------------------------------------------------
%% Options include:
%% {http_options, HttpOptions} = See http(3)
%% {extra_headers, ExtraHeaders} = List of extra HTTP headers
%%--------------------------------------------------------------------
new(Host, Path, Options) ->
State1 = #http_transport{host = Host,
path = Path,
read_buffer = [],
write_buffer = [],
http_options = [],
extra_headers = []},
ApplyOption =
fun
({http_options, HttpOpts}, State = #http_transport{}) ->
State#http_transport{http_options = HttpOpts};
({extra_headers, ExtraHeaders}, State = #http_transport{}) ->
State#http_transport{extra_headers = ExtraHeaders};
(Other, #http_transport{}) ->
{invalid_option, Other};
(_, Error) ->
Error
end,
case lists:foldl(ApplyOption, State1, Options) of
State2 = #http_transport{} ->
thrift_transport:new(?MODULE, State2);
Else ->
{error, Else}
end.
%% Writes data into the buffer
write(State = #http_transport{write_buffer = WBuf}, Data) ->
{State#http_transport{write_buffer = [WBuf, Data]}, ok}.
%% Flushes the buffer, making a request
flush(State = #http_transport{host = Host,
path = Path,
read_buffer = Rbuf,
write_buffer = Wbuf,
http_options = HttpOptions,
extra_headers = ExtraHeaders}) ->
case iolist_to_binary(Wbuf) of
<<>> ->
%% Don't bother flushing empty buffers.
{State, ok};
WBinary ->
{ok, {{_Version, 200, _ReasonPhrase}, _Headers, Body}} =
httpc:request(post,
{"http://" ++ Host ++ Path,
[{"User-Agent", "Erlang/thrift_http_transport"} | ExtraHeaders],
"application/x-thrift",
WBinary},
HttpOptions,
[{body_format, binary}]),
State1 = State#http_transport{read_buffer = [Rbuf, Body],
write_buffer = []},
{State1, ok}
end.
close(State) ->
{State, ok}.
read(State = #http_transport{read_buffer = RBuf}, Len) when is_integer(Len) ->
%% Pull off Give bytes, return them to the user, leave the rest in the buffer.
Give = min(iolist_size(RBuf), Len),
case iolist_to_binary(RBuf) of
<<Data:Give/binary, RBuf1/binary>> ->
Response = {ok, Data},
State1 = State#http_transport{read_buffer=RBuf1},
{State1, Response};
_ ->
{State, {error, 'EOF'}}
end.

View file

@ -0,0 +1,419 @@
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
%% The json parser implementation was created by
%% alisdair sullivan <alisdair@hartbrake.com> based on
%% the jsx json library
-module(thrift_json_parser).
-export([parser/0, handle_event/2]).
-record(config, {strict_utf8 = false :: boolean()}).
parser() -> fun(JSON) -> start(JSON, {?MODULE, []}, [], #config{}) end.
handle_event(Event, {Handler, State}, _Config) -> {Handler, Handler:handle_event(Event, State)}.
handle_event(end_json, State) -> lists:reverse([end_json] ++ State);
handle_event(Event, State) -> [Event] ++ State.
%% whitespace
-define(space, 16#20).
-define(tab, 16#09).
-define(cr, 16#0D).
-define(newline, 16#0A).
%% object delimiters
-define(start_object, 16#7B).
-define(end_object, 16#7D).
%% array delimiters
-define(start_array, 16#5B).
-define(end_array, 16#5D).
%% kv seperator
-define(comma, 16#2C).
-define(doublequote, 16#22).
-define(singlequote, 16#27).
-define(colon, 16#3A).
%% string escape sequences
-define(rsolidus, 16#5C).
-define(solidus, 16#2F).
%% math
-define(zero, 16#30).
-define(decimalpoint, 16#2E).
-define(negative, 16#2D).
-define(positive, 16#2B).
%% comments
-define(star, 16#2A).
%% some useful guards
-define(is_hex(Symbol),
(Symbol >= $a andalso Symbol =< $f) orelse
(Symbol >= $A andalso Symbol =< $F) orelse
(Symbol >= $0 andalso Symbol =< $9)
).
-define(is_nonzero(Symbol),
Symbol >= $1 andalso Symbol =< $9
).
-define(is_whitespace(Symbol),
Symbol =:= ?space; Symbol =:= ?tab; Symbol =:= ?cr; Symbol =:= ?newline
).
%% lists are benchmarked to be faster (tho higher in memory usage) than binaries
new_seq() -> [].
new_seq(C) -> [C].
acc_seq(Seq, C) when is_list(C) -> lists:reverse(C) ++ Seq;
acc_seq(Seq, C) -> [C] ++ Seq.
end_seq(Seq) -> unicode:characters_to_binary(lists:reverse(Seq)).
end_seq(Seq, _) -> end_seq(Seq).
start(<<16#ef, 16#bb, 16#bf, Rest/binary>>, Handler, Stack, Config) ->
value(Rest, Handler, Stack, Config);
start(Bin, Handler, Stack, Config) ->
value(Bin, Handler, Stack, Config).
value(<<?doublequote, Rest/binary>>, Handler, Stack, Config) ->
string(Rest, Handler, new_seq(), Stack, Config);
value(<<$t, Rest/binary>>, Handler, Stack, Config) ->
true(Rest, Handler, Stack, Config);
value(<<$f, Rest/binary>>, Handler, Stack, Config) ->
false(Rest, Handler, Stack, Config);
value(<<$n, Rest/binary>>, Handler, Stack, Config) ->
null(Rest, Handler, Stack, Config);
value(<<?negative, Rest/binary>>, Handler, Stack, Config) ->
negative(Rest, Handler, new_seq($-), Stack, Config);
value(<<?zero, Rest/binary>>, Handler, Stack, Config) ->
zero(Rest, Handler, new_seq($0), Stack, Config);
value(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_nonzero(S) ->
integer(Rest, Handler, new_seq(S), Stack, Config);
value(<<?start_object, Rest/binary>>, Handler, Stack, Config) ->
object(Rest, handle_event(start_object, Handler, Config), [key|Stack], Config);
value(<<?start_array, Rest/binary>>, Handler, Stack, Config) ->
array(Rest, handle_event(start_array, Handler, Config), [array|Stack], Config);
value(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_whitespace(S) ->
value(Rest, Handler, Stack, Config);
value(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
object(<<?doublequote, Rest/binary>>, Handler, Stack, Config) ->
string(Rest, Handler, new_seq(), Stack, Config);
object(<<?end_object, Rest/binary>>, Handler, [key|Stack], Config) ->
maybe_done(Rest, handle_event(end_object, Handler, Config), Stack, Config);
object(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_whitespace(S) ->
object(Rest, Handler, Stack, Config);
object(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
array(<<?end_array, Rest/binary>>, Handler, [array|Stack], Config) ->
maybe_done(Rest, handle_event(end_array, Handler, Config), Stack, Config);
array(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_whitespace(S) ->
array(Rest, Handler, Stack, Config);
array(Bin, Handler, Stack, Config) ->
value(Bin, Handler, Stack, Config).
colon(<<?colon, Rest/binary>>, Handler, [key|Stack], Config) ->
value(Rest, Handler, [object|Stack], Config);
colon(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_whitespace(S) ->
colon(Rest, Handler, Stack, Config);
colon(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
key(<<?doublequote, Rest/binary>>, Handler, Stack, Config) ->
string(Rest, Handler, new_seq(), Stack, Config);
key(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_whitespace(S) ->
key(Rest, Handler, Stack, Config);
key(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
%% note that if you encounter an error from string and you can't find the clause that
%% caused it here, it might be in unescape below
string(<<?doublequote, Rest/binary>>, Handler, Acc, Stack, Config) ->
doublequote(Rest, Handler, Acc, Stack, Config);
string(<<?solidus, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, ?solidus), Stack, Config);
string(<<?rsolidus/utf8, Rest/binary>>, Handler, Acc, Stack, Config) ->
unescape(Rest, Handler, Acc, Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#20, X < 16#2028 ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X == 16#2028; X == 16#2029 ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X > 16#2029, X < 16#d800 ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X > 16#dfff, X < 16#fdd0 ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X > 16#fdef, X < 16#fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#10000, X < 16#1fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#20000, X < 16#2fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#30000, X < 16#3fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#40000, X < 16#4fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#50000, X < 16#5fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#60000, X < 16#6fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#70000, X < 16#7fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#80000, X < 16#8fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#90000, X < 16#9fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#a0000, X < 16#afffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#b0000, X < 16#bfffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#c0000, X < 16#cfffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#d0000, X < 16#dfffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#e0000, X < 16#efffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#f0000, X < 16#ffffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
string(<<X/utf8, Rest/binary>>, Handler, Acc, Stack, Config) when X >= 16#100000, X < 16#10fffe ->
string(Rest, Handler, acc_seq(Acc, X), Stack, Config);
%% surrogates
string(<<237, X, _, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false})
when X >= 160 ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config);
%% u+xfffe, u+xffff, control codes and other noncharacters
string(<<_/utf8, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false}) ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config);
%% u+fffe and u+ffff for R14BXX (subsequent runtimes will happily match the
%% preceding clause
string(<<239, 191, X, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false})
when X == 190; X == 191 ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config);
%% overlong encodings and missing continuations of a 2 byte sequence
string(<<X, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false})
when X >= 192, X =< 223 ->
strip_continuations(Rest, Handler, Acc, Stack, Config, 1);
%% overlong encodings and missing continuations of a 3 byte sequence
string(<<X, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false})
when X >= 224, X =< 239 ->
strip_continuations(Rest, Handler, Acc, Stack, Config, 2);
%% overlong encodings and missing continuations of a 4 byte sequence
string(<<X, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false})
when X >= 240, X =< 247 ->
strip_continuations(Rest, Handler, Acc, Stack, Config, 3);
%% incompletes and unexpected bytes, including orphan continuations
string(<<_, Rest/binary>>, Handler, Acc, Stack, Config=#config{strict_utf8=false}) ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config);
string(_Bin, _Handler, _Acc, _Stack, _Config) ->
erlang:error(badarg).
doublequote(Rest, Handler, Acc, [key|_] = Stack, Config) ->
colon(Rest, handle_event({key, end_seq(Acc, Config)}, Handler, Config), Stack, Config);
doublequote(Rest, Handler, Acc, Stack, Config) ->
maybe_done(Rest, handle_event({string, end_seq(Acc, Config)}, Handler, Config), Stack, Config).
%% strips continuation bytes after bad utf bytes, guards against both too short
%% and overlong sequences. N is the maximum number of bytes to strip
strip_continuations(<<Rest/binary>>, Handler, Acc, Stack, Config, 0) ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config);
strip_continuations(<<X, Rest/binary>>, Handler, Acc, Stack, Config, N) when X >= 128, X =< 191 ->
strip_continuations(Rest, Handler, Acc, Stack, Config, N - 1);
%% not a continuation byte, insert a replacement character for sequence thus
%% far and dispatch back to string
strip_continuations(<<Rest/binary>>, Handler, Acc, Stack, Config, _) ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config).
%% this all gets really gross and should probably eventually be folded into
%% but for now it fakes being part of string on incompletes and errors
unescape(<<$b, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\b), Stack, Config);
unescape(<<$f, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\f), Stack, Config);
unescape(<<$n, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\n), Stack, Config);
unescape(<<$r, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\r), Stack, Config);
unescape(<<$t, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\t), Stack, Config);
unescape(<<?doublequote, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\"), Stack, Config);
unescape(<<?rsolidus, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $\\), Stack, Config);
unescape(<<?solidus, Rest/binary>>, Handler, Acc, Stack, Config) ->
string(Rest, Handler, acc_seq(Acc, $/), Stack, Config);
unescape(<<$u, $d, A, B, C, ?rsolidus, $u, $d, X, Y, Z, Rest/binary>>, Handler, Acc, Stack, Config)
when (A == $8 orelse A == $9 orelse A == $a orelse A == $b),
(X == $c orelse X == $d orelse X == $e orelse X == $f),
?is_hex(B), ?is_hex(C), ?is_hex(Y), ?is_hex(Z)
->
High = erlang:list_to_integer([$d, A, B, C], 16),
Low = erlang:list_to_integer([$d, X, Y, Z], 16),
Codepoint = (High - 16#d800) * 16#400 + (Low - 16#dc00) + 16#10000,
string(Rest, Handler, acc_seq(Acc, Codepoint), Stack, Config);
unescape(<<$u, $d, A, B, C, ?rsolidus, $u, W, X, Y, Z, Rest/binary>>, Handler, Acc, Stack, Config)
when (A == $8 orelse A == $9 orelse A == $a orelse A == $b),
?is_hex(B), ?is_hex(C), ?is_hex(W), ?is_hex(X), ?is_hex(Y), ?is_hex(Z)
->
string(Rest, Handler, acc_seq(Acc, [16#fffd, 16#fffd]), Stack, Config);
unescape(<<$u, A, B, C, D, Rest/binary>>, Handler, Acc, Stack, Config)
when ?is_hex(A), ?is_hex(B), ?is_hex(C), ?is_hex(D) ->
case erlang:list_to_integer([A, B, C, D], 16) of
Codepoint when Codepoint < 16#d800; Codepoint > 16#dfff ->
string(Rest, Handler, acc_seq(Acc, Codepoint), Stack, Config);
_ ->
string(Rest, Handler, acc_seq(Acc, 16#fffd), Stack, Config)
end;
unescape(_Bin, _Handler, _Acc, _Stack, _Config) ->
erlang:error(badarg).
%% like in strings, there's some pseudo states in here that will never
%% show up in errors or incompletes. some show up in value, some show
%% up in integer, decimal or exp
negative(<<$0, Rest/binary>>, Handler, Acc, Stack, Config) ->
zero(Rest, Handler, acc_seq(Acc, $0), Stack, Config);
negative(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when ?is_nonzero(S) ->
integer(Rest, Handler, acc_seq(Acc, S), Stack, Config);
negative(_Bin, _Handler, _Acc, _Stack, _Config) ->
erlang:error(badarg).
zero(<<?decimalpoint, Rest/binary>>, Handler, Acc, Stack, Config) ->
decimal(Rest, Handler, acc_seq(Acc, ?decimalpoint), Stack, Config);
zero(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= $e; S =:= $E ->
e(Rest, Handler, acc_seq(Acc, ".0e"), Stack, Config);
zero(Bin, Handler, Acc, Stack, Config) ->
finish_number(Bin, Handler, {zero, Acc}, Stack, Config).
integer(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= ?zero; ?is_nonzero(S) ->
integer(Rest, Handler, acc_seq(Acc, S), Stack, Config);
integer(<<?decimalpoint, Rest/binary>>, Handler, Acc, Stack, Config) ->
initialdecimal(Rest, Handler, acc_seq(Acc, ?decimalpoint), Stack, Config);
integer(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= $e; S =:= $E ->
e(Rest, Handler, acc_seq(Acc, ".0e"), Stack, Config);
integer(Bin, Handler, Acc, Stack, Config) ->
finish_number(Bin, Handler, {integer, Acc}, Stack, Config).
initialdecimal(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= ?zero; ?is_nonzero(S) ->
decimal(Rest, Handler, acc_seq(Acc, S), Stack, Config);
initialdecimal(_Bin, _Handler, _Acc, _Stack, _Config) ->
erlang:error(badarg).
decimal(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= ?zero; ?is_nonzero(S) ->
decimal(Rest, Handler, acc_seq(Acc, S), Stack, Config);
decimal(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= $e; S =:= $E ->
e(Rest, Handler, acc_seq(Acc, $e), Stack, Config);
decimal(Bin, Handler, Acc, Stack, Config) ->
finish_number(Bin, Handler, {decimal, Acc}, Stack, Config).
e(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= ?zero; ?is_nonzero(S) ->
exp(Rest, Handler, acc_seq(Acc, S), Stack, Config);
e(<<Sign, Rest/binary>>, Handler, Acc, Stack, Config) when Sign =:= ?positive; Sign =:= ?negative ->
ex(Rest, Handler, acc_seq(Acc, Sign), Stack, Config);
e(_Bin, _Handler, _Acc, _Stack, _Config) ->
erlang:error(badarg).
ex(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= ?zero; ?is_nonzero(S) ->
exp(Rest, Handler, acc_seq(Acc, S), Stack, Config);
ex(_Bin, _Handler, _Acc, _Stack, _Config) ->
erlang:error(badarg).
exp(<<S, Rest/binary>>, Handler, Acc, Stack, Config) when S =:= ?zero; ?is_nonzero(S) ->
exp(Rest, Handler, acc_seq(Acc, S), Stack, Config);
exp(Bin, Handler, Acc, Stack, Config) ->
finish_number(Bin, Handler, {exp, Acc}, Stack, Config).
finish_number(Rest, Handler, Acc, [], Config) ->
maybe_done(Rest, handle_event(format_number(Acc), Handler, Config), [], Config);
finish_number(Rest, Handler, Acc, Stack, Config) ->
maybe_done(Rest, handle_event(format_number(Acc), Handler, Config), Stack, Config).
format_number({zero, Acc}) -> {integer, list_to_integer(lists:reverse(Acc))};
format_number({integer, Acc}) -> {integer, list_to_integer(lists:reverse(Acc))};
format_number({decimal, Acc}) -> {float, list_to_float(lists:reverse(Acc))};
format_number({exp, Acc}) -> {float, list_to_float(lists:reverse(Acc))}.
true(<<$r, $u, $e, Rest/binary>>, Handler, Stack, Config) ->
maybe_done(Rest, handle_event({literal, true}, Handler, Config), Stack, Config);
true(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
false(<<$a, $l, $s, $e, Rest/binary>>, Handler, Stack, Config) ->
maybe_done(Rest, handle_event({literal, false}, Handler, Config), Stack, Config);
false(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
null(<<$u, $l, $l, Rest/binary>>, Handler, Stack, Config) ->
maybe_done(Rest, handle_event({literal, null}, Handler, Config), Stack, Config);
null(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
maybe_done(<<Rest/binary>>, Handler, [], Config) ->
done(Rest, handle_event(end_json, Handler, Config), [], Config);
maybe_done(<<?end_object, Rest/binary>>, Handler, [object|Stack], Config) ->
maybe_done(Rest, handle_event(end_object, Handler, Config), Stack, Config);
maybe_done(<<?end_array, Rest/binary>>, Handler, [array|Stack], Config) ->
maybe_done(Rest, handle_event(end_array, Handler, Config), Stack, Config);
maybe_done(<<?comma, Rest/binary>>, Handler, [object|Stack], Config) ->
key(Rest, Handler, [key|Stack], Config);
maybe_done(<<?comma, Rest/binary>>, Handler, [array|_] = Stack, Config) ->
value(Rest, Handler, Stack, Config);
maybe_done(<<S, Rest/binary>>, Handler, Stack, Config) when ?is_whitespace(S) ->
maybe_done(Rest, Handler, Stack, Config);
maybe_done(_Bin, _Handler, _Stack, _Config) ->
erlang:error(badarg).
done(<<S, Rest/binary>>, Handler, [], Config) when ?is_whitespace(S) ->
done(Rest, Handler, [], Config);
done(<<>>, {_Handler, State}, [], _Config) -> State;
done(_Bin, _Handler, _Stack, _Config) -> erlang:error(badarg).

View file

@ -0,0 +1,567 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
%% The JSON protocol implementation was created by
%% Peter Neumark <neumark.peter@gmail.com> based on
%% the binary protocol implementation.
-module(thrift_json_protocol).
-behaviour(thrift_protocol).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-export([new/1, new/2,
read/2,
write/2,
flush_transport/1,
close_transport/1,
new_protocol_factory/2
]).
-record(json_context, {
% the type of json_context: array or object
type,
% fields read or written
fields_processed = 0
}).
-record(json_protocol, {
transport,
context_stack = [],
jsx
}).
-type state() :: #json_protocol{}.
-include("thrift_protocol_behaviour.hrl").
-define(VERSION_1, 1).
-define(JSON_DOUBLE_PRECISION, 16).
typeid_to_json(?tType_BOOL) -> "tf";
typeid_to_json(?tType_BYTE) -> "i8";
typeid_to_json(?tType_DOUBLE) -> "dbl";
typeid_to_json(?tType_I8) -> "i8";
typeid_to_json(?tType_I16) -> "i16";
typeid_to_json(?tType_I32) -> "i32";
typeid_to_json(?tType_I64) -> "i64";
typeid_to_json(?tType_STRING) -> "str";
typeid_to_json(?tType_STRUCT) -> "rec";
typeid_to_json(?tType_MAP) -> "map";
typeid_to_json(?tType_SET) -> "set";
typeid_to_json(?tType_LIST) -> "lst".
json_to_typeid("tf") -> ?tType_BOOL;
json_to_typeid("dbl") -> ?tType_DOUBLE;
json_to_typeid("i8") -> ?tType_I8;
json_to_typeid("i16") -> ?tType_I16;
json_to_typeid("i32") -> ?tType_I32;
json_to_typeid("i64") -> ?tType_I64;
json_to_typeid("str") -> ?tType_STRING;
json_to_typeid("rec") -> ?tType_STRUCT;
json_to_typeid("map") -> ?tType_MAP;
json_to_typeid("set") -> ?tType_SET;
json_to_typeid("lst") -> ?tType_LIST.
start_context(object) -> "{";
start_context(array) -> "[".
end_context(object) -> "}";
end_context(array) -> "]".
new(Transport) ->
new(Transport, _Options = []).
new(Transport, _Options) ->
State = #json_protocol{transport = Transport},
thrift_protocol:new(?MODULE, State).
flush_transport(This = #json_protocol{transport = Transport}) ->
{NewTransport, Result} = thrift_transport:flush(Transport),
{This#json_protocol{
transport = NewTransport,
context_stack = []
}, Result}.
close_transport(This = #json_protocol{transport = Transport}) ->
{NewTransport, Result} = thrift_transport:close(Transport),
{This#json_protocol{
transport = NewTransport,
context_stack = [],
jsx = undefined
}, Result}.
%%%
%%% instance methods
%%%
% places a new context on the stack:
write(#json_protocol{context_stack = Stack} = State0, {enter_context, Type}) ->
{State1, ok} = write_values(State0, [{context_pre_item, false}]),
State2 = State1#json_protocol{context_stack = [
#json_context{type=Type}|Stack]},
write_values(State2, [list_to_binary(start_context(Type))]);
% removes the topmost context from stack
write(#json_protocol{context_stack = [CurrCtxt|Stack]} = State0, {exit_context}) ->
Type = CurrCtxt#json_context.type,
State1 = State0#json_protocol{context_stack = Stack},
write_values(State1, [
list_to_binary(end_context(Type)),
{context_post_item, false}
]);
% writes necessary prelude to field or container depending on current context
write(#json_protocol{context_stack = []} = This0,
{context_pre_item, _}) -> {This0, ok};
write(#json_protocol{context_stack = [Context|_CtxtTail]} = This0,
{context_pre_item, MayNeedQuotes}) ->
FieldNo = Context#json_context.fields_processed,
CtxtType = Context#json_context.type,
Rem = FieldNo rem 2,
case {CtxtType, FieldNo, Rem, MayNeedQuotes} of
{array, N, _, _} when N > 0 -> % array element (not first)
write(This0, <<",">>);
{object, 0, _, true} -> % non-string object key (first)
write(This0, <<"\"">>);
{object, N, 0, true} when N > 0 -> % non-string object key (not first)
write(This0, <<",\"">>);
{object, N, 0, false} when N > 0-> % string object key (not first)
write(This0, <<",">>);
_ -> % no pre-field necessary
{This0, ok}
end;
% writes necessary postlude to field or container depending on current context
write(#json_protocol{context_stack = []} = This0,
{context_post_item, _}) -> {This0, ok};
write(#json_protocol{context_stack = [Context|CtxtTail]} = This0,
{context_post_item, MayNeedQuotes}) ->
FieldNo = Context#json_context.fields_processed,
CtxtType = Context#json_context.type,
Rem = FieldNo rem 2,
{This1, ok} = case {CtxtType, Rem, MayNeedQuotes} of
{object, 0, true} -> % non-string object key
write(This0, <<"\":">>);
{object, 0, false} -> % string object key
write(This0, <<":">>);
_ -> % no pre-field necessary
{This0, ok}
end,
NewContext = Context#json_context{fields_processed = FieldNo + 1},
{This1#json_protocol{context_stack=[NewContext|CtxtTail]}, ok};
write(This0, #protocol_message_begin{
name = Name,
type = Type,
seqid = Seqid}) ->
write_values(This0, [
{enter_context, array},
{i32, ?VERSION_1},
{string, Name},
{i32, Type},
{i32, Seqid}
]);
write(This, message_end) ->
write_values(This, [{exit_context}]);
% Example field expression: "1":{"dbl":3.14}
write(This0, #protocol_field_begin{
name = _Name,
type = Type,
id = Id}) ->
write_values(This0, [
% entering 'outer' object
{i16, Id},
% entering 'outer' object
{enter_context, object},
{string, typeid_to_json(Type)}
]);
write(This, field_stop) ->
{This, ok};
write(This, field_end) ->
write_values(This,[{exit_context}]);
% Example message with map: [1,"testMap",1,0,{"1":{"map":["i32","i32",3,{"7":77,"8":88,"9":99}]}}]
write(This0, #protocol_map_begin{
ktype = Ktype,
vtype = Vtype,
size = Size}) ->
write_values(This0, [
{enter_context, array},
{string, typeid_to_json(Ktype)},
{string, typeid_to_json(Vtype)},
{i32, Size},
{enter_context, object}
]);
write(This, map_end) ->
write_values(This,[
{exit_context},
{exit_context}
]);
write(This0, #protocol_list_begin{
etype = Etype,
size = Size}) ->
write_values(This0, [
{enter_context, array},
{string, typeid_to_json(Etype)},
{i32, Size}
]);
write(This, list_end) ->
write_values(This,[
{exit_context}
]);
% example message with set: [1,"testSet",1,0,{"1":{"set":["i32",3,1,2,3]}}]
write(This0, #protocol_set_begin{
etype = Etype,
size = Size}) ->
write_values(This0, [
{enter_context, array},
{string, typeid_to_json(Etype)},
{i32, Size}
]);
write(This, set_end) ->
write_values(This,[
{exit_context}
]);
% example message with struct: [1,"testStruct",1,0,{"1":{"rec":{"1":{"str":"worked"},"4":{"i8":1},"9":{"i32":1073741824},"11":{"i64":1152921504606847000}}}}]
write(This, #protocol_struct_begin{}) ->
write_values(This, [
{enter_context, object}
]);
write(This, struct_end) ->
write_values(This,[
{exit_context}
]);
write(This, {bool, true}) -> write_values(This, [
{context_pre_item, true},
<<"true">>,
{context_post_item, true}
]);
write(This, {bool, false}) -> write_values(This, [
{context_pre_item, true},
<<"false">>,
{context_post_item, true}
]);
write(This, {byte, Byte}) -> write_values(This, [
{context_pre_item, true},
list_to_binary(integer_to_list(Byte)),
{context_post_item, true}
]);
write(This, {i16, I16}) ->
write(This, {byte, I16});
write(This, {i32, I32}) ->
write(This, {byte, I32});
write(This, {i64, I64}) ->
write(This, {byte, I64});
write(This, {double, Double}) -> write_values(This, [
{context_pre_item, true},
list_to_binary(io_lib:format("~.*f", [?JSON_DOUBLE_PRECISION,Double])),
{context_post_item, true}
]);
write(This0, {string, Str}) -> write_values(This0, [
{context_pre_item, false},
case is_binary(Str) of
true -> Str;
false -> <<"\"", (list_to_binary(Str))/binary, "\"">>
end,
{context_post_item, false}
]);
%% TODO: binary fields should be base64 encoded?
%% Data :: iolist()
write(This = #json_protocol{transport = Trans}, Data) ->
%io:format("Data ~p Ctxt ~p~n~n", [Data, This#json_protocol.context_stack]),
{NewTransport, Result} = thrift_transport:write(Trans, Data),
{This#json_protocol{transport = NewTransport}, Result}.
write_values(This0, ValueList) ->
FinalState = lists:foldl(
fun(Val, ThisIn) ->
{ThisOut, ok} = write(ThisIn, Val),
ThisOut
end,
This0,
ValueList),
{FinalState, ok}.
%% I wish the erlang version of the transport interface included a
%% read_all function (like eg. the java implementation). Since it doesn't,
%% here's my version (even though it probably shouldn't be in this file).
%%
%% The resulting binary is immediately send to the JSX stream parser.
%% Subsequent calls to read actually operate on the events returned by JSX.
read_all(#json_protocol{transport = Transport0} = State) ->
{Transport1, Bin} = read_all_1(Transport0, []),
P = thrift_json_parser:parser(),
[First|Rest] = P(Bin),
State#json_protocol{
transport = Transport1,
jsx = {event, First, Rest}
}.
read_all_1(Transport0, IoList) ->
{Transport1, Result} = thrift_transport:read(Transport0, 1),
case Result of
{ok, <<>>} -> % nothing read: assume we're done
{Transport1, iolist_to_binary(lists:reverse(IoList))};
{ok, Data} -> % character successfully read; read more
read_all_1(Transport1, [Data|IoList]);
{error, 'EOF'} -> % we're done
{Transport1, iolist_to_binary(lists:reverse(IoList))}
end.
% Expect reads an event from the JSX event stream. It receives an event or data
% type as input. Comparing the read event from the one is was passed, it
% returns an error if something other than the expected value is encountered.
% Expect also maintains the context stack in #json_protocol.
expect(#json_protocol{jsx={event, {Type, Data}=Ev, [Next|Rest]}}=State, ExpectedType) ->
NextState = State#json_protocol{jsx={event, Next, Rest}},
case Type == ExpectedType of
true ->
{NextState, {ok, convert_data(Type, Data)}};
false ->
{NextState, {error, {unexpected_json_event, Ev}}}
end;
expect(#json_protocol{jsx={event, Event, Next}}=State, ExpectedEvent) ->
expect(State#json_protocol{jsx={event, {Event, none}, Next}}, ExpectedEvent).
convert_data(integer, I) -> list_to_integer(I);
convert_data(float, F) -> list_to_float(F);
convert_data(_, D) -> D.
expect_many(State, ExpectedList) ->
expect_many_1(State, ExpectedList, [], ok).
expect_many_1(State, [], ResultList, Status) ->
{State, {Status, lists:reverse(ResultList)}};
expect_many_1(State, [Expected|ExpTail], ResultList, _PrevStatus) ->
{State1, {Status, Data}} = expect(State, Expected),
NewResultList = [Data|ResultList],
case Status of
% in case of error, end prematurely
error -> expect_many_1(State1, [], NewResultList, Status);
ok -> expect_many_1(State1, ExpTail, NewResultList, Status)
end.
% wrapper around expect to make life easier for container opening/closing functions
expect_nodata(This, ExpectedList) ->
case expect_many(This, ExpectedList) of
{State, {ok, _}} ->
{State, ok};
Error ->
Error
end.
read_field(#json_protocol{jsx={event, Field, [Next|Rest]}} = State) ->
NewState = State#json_protocol{jsx={event, Next, Rest}},
{NewState, Field}.
read(This0, message_begin) ->
% call read_all to get the contents of the transport buffer into JSX.
This1 = read_all(This0),
case expect_many(This1,
[start_array, integer, string, integer, integer]) of
{This2, {ok, [_, Version, Name, Type, SeqId]}} ->
case Version =:= ?VERSION_1 of
true ->
{This2, #protocol_message_begin{name = Name,
type = Type,
seqid = SeqId}};
false ->
{This2, {error, no_json_protocol_version}}
end;
Other -> Other
end;
read(This, message_end) ->
expect_nodata(This, [end_array]);
read(This, struct_begin) ->
expect_nodata(This, [start_object]);
read(This, struct_end) ->
expect_nodata(This, [end_object]);
read(This0, field_begin) ->
{This1, Read} = expect_many(This0,
[%field id
key,
% {} surrounding field
start_object,
% type of field
key]),
case Read of
{ok, [FieldIdStr, _, FieldType]} ->
{This1, #protocol_field_begin{
type = json_to_typeid(FieldType),
id = list_to_integer(FieldIdStr)}}; % TODO: do we need to wrap this in a try/catch?
{error,[{unexpected_json_event, {end_object,none}}]} ->
{This1, #protocol_field_begin{type = ?tType_STOP}};
Other ->
io:format("**** OTHER branch selected ****"),
{This1, Other}
end;
read(This, field_end) ->
expect_nodata(This, [end_object]);
% Example message with map: [1,"testMap",1,0,{"1":{"map":["i32","i32",3,{"7":77,"8":88,"9":99}]}}]
read(This0, map_begin) ->
case expect_many(This0,
[start_array,
% key type
string,
% value type
string,
% size
integer,
% the following object contains the map
start_object]) of
{This1, {ok, [_, Ktype, Vtype, Size, _]}} ->
{This1, #protocol_map_begin{ktype = Ktype,
vtype = Vtype,
size = Size}};
Other -> Other
end;
read(This, map_end) ->
expect_nodata(This, [end_object, end_array]);
read(This0, list_begin) ->
case expect_many(This0,
[start_array,
% element type
string,
% size
integer]) of
{This1, {ok, [_, Etype, Size]}} ->
{This1, #protocol_list_begin{
etype = Etype,
size = Size}};
Other -> Other
end;
read(This, list_end) ->
expect_nodata(This, [end_array]);
% example message with set: [1,"testSet",1,0,{"1":{"set":["i32",3,1,2,3]}}]
read(This0, set_begin) ->
case expect_many(This0,
[start_array,
% element type
string,
% size
integer]) of
{This1, {ok, [_, Etype, Size]}} ->
{This1, #protocol_set_begin{
etype = Etype,
size = Size}};
Other -> Other
end;
read(This, set_end) ->
expect_nodata(This, [end_array]);
read(This0, field_stop) ->
{This0, ok};
%%
read(This0, bool) ->
{This1, Field} = read_field(This0),
Value = case Field of
{literal, I} ->
{ok, I};
_Other ->
{error, unexpected_event_for_boolean}
end,
{This1, Value};
read(This0, byte) ->
{This1, Field} = read_field(This0),
Value = case Field of
{key, K} ->
{ok, list_to_integer(K)};
{integer, I} ->
{ok, list_to_integer(I)};
_Other ->
{error, unexpected_event_for_integer}
end,
{This1, Value};
read(This0, i16) ->
read(This0, byte);
read(This0, i32) ->
read(This0, byte);
read(This0, i64) ->
read(This0, byte);
read(This0, double) ->
{This1, Field} = read_field(This0),
Value = case Field of
{float, I} ->
{ok, list_to_float(I)};
_Other ->
{error, unexpected_event_for_double}
end,
{This1, Value};
% returns a binary directly, call binary_to_list if necessary
read(This0, string) ->
{This1, Field} = read_field(This0),
Value = case Field of
{string, I} ->
{ok, I};
{key, J} ->
{ok, J};
_Other ->
{error, unexpected_event_for_string}
end,
{This1, Value}.
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% returns a (fun() -> thrift_protocol())
new_protocol_factory(TransportFactory, _Options) ->
% Only strice read/write are implemented
F = fun() ->
{ok, Transport} = TransportFactory(),
thrift_json_protocol:new(Transport, [])
end,
{ok, F}.

View file

@ -0,0 +1,83 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_membuffer_transport).
-behaviour(thrift_transport).
%% constructors
-export([new/0, new/1]).
%% protocol callbacks
-export([read/2, read_exact/2, write/2, flush/1, close/1]).
-record(t_membuffer, {
buffer = []
}).
-type state() :: #t_membuffer{}.
-spec new() -> thrift_transport:t_transport().
new() -> new([]).
-spec new(Buf::iodata()) -> thrift_transport:t_transport().
new(Buf) when is_list(Buf) ->
State = #t_membuffer{buffer = Buf},
thrift_transport:new(?MODULE, State);
new(Buf) when is_binary(Buf) ->
State = #t_membuffer{buffer = [Buf]},
thrift_transport:new(?MODULE, State).
-include("thrift_transport_behaviour.hrl").
read(State = #t_membuffer{buffer = Buf}, Len)
when is_integer(Len), Len >= 0 ->
Binary = iolist_to_binary(Buf),
Give = min(iolist_size(Binary), Len),
{Result, Remaining} = split_binary(Binary, Give),
{State#t_membuffer{buffer = Remaining}, {ok, Result}}.
read_exact(State = #t_membuffer{buffer = Buf}, Len)
when is_integer(Len), Len >= 0 ->
Binary = iolist_to_binary(Buf),
case iolist_size(Binary) of
X when X >= Len ->
{Result, Remaining} = split_binary(Binary, Len),
{State#t_membuffer{buffer = Remaining}, {ok, Result}};
_ ->
{State, {error, eof}}
end.
write(State = #t_membuffer{buffer = Buf}, Data)
when is_list(Data); is_binary(Data) ->
{State#t_membuffer{buffer = [Buf, Data]}, ok}.
flush(State) -> {State, ok}.
close(State) -> {State, ok}.

View file

@ -0,0 +1,47 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_memory_buffer).
-behaviour(thrift_transport).
%% constructors
-export([new/0, new/1]).
%% protocol callbacks
-export([read/2, write/2, flush/1, close/1]).
%% legacy api
-export([new_transport_factory/0]).
%% wrapper around thrift_membuffer_transport for legacy reasons
new() -> thrift_membuffer_transport:new().
new(State) -> thrift_membuffer_transport:new(State).
new_transport_factory() -> {ok, fun() -> new() end}.
write(State, Data) -> thrift_membuffer_transport:write(State, Data).
read(State, Data) -> thrift_membuffer_transport:read(State, Data).
flush(State) -> thrift_membuffer_transport:flush(State).
close(State) -> thrift_membuffer_transport:close(State).

View file

@ -0,0 +1,57 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_multiplexed_map_wrapper).
-export([
new/0
,store/3
,find/2
,fetch/2
]).
-type service_handler() :: nonempty_string().
-type module_() :: atom().
-type service_handler_map() :: [{ServiceHandler::service_handler(), Module::module_()}].
-spec new() -> service_handler_map().
new() ->
orddict:new().
-spec store(ServiceHandler, Module, Map) -> NewMap when
ServiceHandler :: service_handler(),
Module :: module_(),
Map :: service_handler_map(),
NewMap :: service_handler_map().
store(ServiceHandler, Module, Map) ->
orddict:store(ServiceHandler, Module, Map).
-spec find(ServiceHandler, Map) -> {ok, Module} | error when
ServiceHandler :: service_handler(),
Module :: module_(),
Map :: service_handler_map().
find(ServiceHandler, Map) ->
orddict:find(ServiceHandler, Map).
-spec fetch(ServiceHandler, Map) -> Module when
ServiceHandler :: service_handler(),
Module :: module_(),
Map :: service_handler_map().
fetch(ServiceHandler, Map) ->
orddict:fetch(ServiceHandler, Map).

View file

@ -0,0 +1,83 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_multiplexed_protocol).
-behaviour(thrift_protocol).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-include("thrift_protocol_behaviour.hrl").
-export([new/2,
read/2,
write/2,
flush_transport/1,
close_transport/1
]).
-record(protocol, {module, data}).
-type protocol() :: #protocol{}.
-record (multiplexed_protocol, {protocol_module_to_decorate::atom(),
protocol_data_to_decorate::term(),
service_name::nonempty_string()}).
-type state() :: #multiplexed_protocol{}.
-spec new(ProtocolToDecorate::protocol(), ServiceName::nonempty_string()) -> {ok, Protocol::protocol()}.
new(ProtocolToDecorate, ServiceName) when is_record(ProtocolToDecorate, protocol),
is_list(ServiceName) ->
State = #multiplexed_protocol{protocol_module_to_decorate = ProtocolToDecorate#protocol.module,
protocol_data_to_decorate = ProtocolToDecorate#protocol.data,
service_name = ServiceName},
thrift_protocol:new(?MODULE, State).
flush_transport(State = #multiplexed_protocol{protocol_module_to_decorate = ProtocolModuleToDecorate,
protocol_data_to_decorate = State0}) ->
{State1, ok} = ProtocolModuleToDecorate:flush_transport(State0),
{State#multiplexed_protocol{protocol_data_to_decorate = State1}, ok}.
close_transport(State = #multiplexed_protocol{protocol_module_to_decorate = ProtocolModuleToDecorate,
protocol_data_to_decorate = State0}) ->
{State1, ok} = ProtocolModuleToDecorate:close_transport(State0),
{State#multiplexed_protocol{protocol_data_to_decorate = State1}, ok}.
write(State = #multiplexed_protocol{protocol_module_to_decorate = ProtocolModuleToDecorate,
protocol_data_to_decorate = State0,
service_name = ServiceName},
Message = #protocol_message_begin{name = Name}) ->
{State1, ok} = ProtocolModuleToDecorate:write(State0,
Message#protocol_message_begin{name=ServiceName ++
?MULTIPLEXED_SERVICE_SEPARATOR ++
Name}),
{State#multiplexed_protocol{protocol_data_to_decorate = State1}, ok};
write(State = #multiplexed_protocol{protocol_module_to_decorate = ProtocolModuleToDecorate,
protocol_data_to_decorate = State0},
Message) ->
{State1, ok} = ProtocolModuleToDecorate:write(State0, Message),
{State#multiplexed_protocol{protocol_data_to_decorate = State1}, ok}.
read(State = #multiplexed_protocol{protocol_module_to_decorate = ProtocolModuleToDecorate,
protocol_data_to_decorate = State0},
Message) ->
{State1, Result} = ProtocolModuleToDecorate:read(State0, Message),
{State#multiplexed_protocol{protocol_data_to_decorate = State1}, Result}.

View file

@ -0,0 +1,219 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_processor).
-export([init/1]).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-record(thrift_processor, {handler, protocol, service}).
init({_Server, ProtoGen, Service, Handler}) when is_function(ProtoGen, 0) ->
{ok, Proto} = ProtoGen(),
loop(#thrift_processor{protocol = Proto,
service = Service,
handler = Handler}).
loop(State0 = #thrift_processor{protocol = Proto0,
handler = Handler,
service = Service}) ->
{Proto1, MessageBegin} = thrift_protocol:read(Proto0, message_begin),
State1 = State0#thrift_processor{protocol = Proto1},
ErrorHandler = fun
(HandlerModules) when is_list(HandlerModules) -> thrift_multiplexed_map_wrapper:fetch(?MULTIPLEXED_ERROR_HANDLER_KEY, HandlerModules);
(HandlerModule) -> HandlerModule
end,
case MessageBegin of
#protocol_message_begin{name = Function,
type = Type,
seqid = Seqid} when Type =:= ?tMessageType_CALL; Type =:= ?tMessageType_ONEWAY ->
case string:tokens(Function, ?MULTIPLEXED_SERVICE_SEPARATOR) of
[ServiceName, FunctionName] ->
ServiceModule = thrift_multiplexed_map_wrapper:fetch(ServiceName, Service),
ServiceHandler = thrift_multiplexed_map_wrapper:fetch(ServiceName, Handler),
case handle_function(State1#thrift_processor{service=ServiceModule, handler=ServiceHandler}, list_to_atom(FunctionName), Seqid) of
{State2, ok} -> loop(State2#thrift_processor{service=Service, handler=Handler});
{_State2, {error, Reason}} ->
apply(ErrorHandler(Handler), handle_error, [list_to_atom(Function), Reason]),
thrift_protocol:close_transport(Proto1),
ok
end;
_ ->
case handle_function(State1, list_to_atom(Function), Seqid) of
{State2, ok} -> loop(State2);
{_State2, {error, Reason}} ->
apply(ErrorHandler(Handler), handle_error, [list_to_atom(Function), Reason]),
thrift_protocol:close_transport(Proto1),
ok
end
end;
{error, timeout = Reason} ->
apply(ErrorHandler(Handler), handle_error, [undefined, Reason]),
thrift_protocol:close_transport(Proto1),
ok;
{error, closed = Reason} ->
%% error_logger:info_msg("Client disconnected~n"),
apply(ErrorHandler(Handler), handle_error, [undefined, Reason]),
thrift_protocol:close_transport(Proto1),
exit(shutdown);
{error, Reason} ->
apply(ErrorHandler(Handler), handle_error, [undefined, Reason]),
thrift_protocol:close_transport(Proto1),
exit(shutdown)
end.
handle_function(State0=#thrift_processor{protocol = Proto0,
handler = Handler,
service = Service},
Function,
Seqid) ->
InParams = Service:function_info(Function, params_type),
{Proto1, {ok, Params}} = thrift_protocol:read(Proto0, InParams),
State1 = State0#thrift_processor{protocol = Proto1},
try
Result = Handler:handle_function(Function, Params),
%% {Micro, Result} = better_timer(Handler, handle_function, [Function, Params]),
%% error_logger:info_msg("Processed ~p(~p) in ~.4fms~n",
%% [Function, Params, Micro/1000.0]),
handle_success(State1, Function, Result, Seqid)
catch
Type:Data when Type =:= throw orelse Type =:= error ->
handle_function_catch(State1, Function, Type, Data, Seqid)
end.
handle_function_catch(State = #thrift_processor{service = Service},
Function, ErrType, ErrData, Seqid) ->
IsOneway = Service:function_info(Function, reply_type) =:= oneway_void,
case {ErrType, ErrData} of
_ when IsOneway ->
Stack = erlang:get_stacktrace(),
error_logger:warning_msg(
"oneway void ~p threw error which must be ignored: ~p",
[Function, {ErrType, ErrData, Stack}]),
{State, ok};
{throw, Exception} when is_tuple(Exception), size(Exception) > 0 ->
%error_logger:warning_msg("~p threw exception: ~p~n", [Function, Exception]),
handle_exception(State, Function, Exception, Seqid);
% we still want to accept more requests from this client
{error, Error} ->
handle_error(State, Function, Error, Seqid)
end.
handle_success(State = #thrift_processor{service = Service},
Function,
Result,
Seqid) ->
ReplyType = Service:function_info(Function, reply_type),
StructName = atom_to_list(Function) ++ "_result",
case Result of
{reply, ReplyData} ->
Reply = {{struct, [{0, ReplyType}]}, {StructName, ReplyData}},
send_reply(State, Function, ?tMessageType_REPLY, Reply, Seqid);
ok when ReplyType == {struct, []} ->
send_reply(State, Function, ?tMessageType_REPLY, {ReplyType, {StructName}}, Seqid);
ok when ReplyType == oneway_void ->
%% no reply for oneway void
{State, ok}
end.
handle_exception(State = #thrift_processor{service = Service},
Function,
Exception,
Seqid) ->
ExceptionType = element(1, Exception),
%% Fetch a structure like {struct, [{-2, {struct, {Module, Type}}},
%% {-3, {struct, {Module, Type}}}]}
ReplySpec = Service:function_info(Function, exceptions),
{struct, XInfo} = ReplySpec,
true = is_list(XInfo),
%% Assuming we had a type1 exception, we'd get: [undefined, Exception, undefined]
%% e.g.: [{-1, type0}, {-2, type1}, {-3, type2}]
ExceptionList = [case Type of
ExceptionType -> Exception;
_ -> undefined
end
|| {_Fid, {struct, {_Module, Type}}} <- XInfo],
ExceptionTuple = list_to_tuple([Function | ExceptionList]),
% Make sure we got at least one defined
case lists:all(fun(X) -> X =:= undefined end, ExceptionList) of
true ->
handle_unknown_exception(State, Function, Exception, Seqid);
false ->
send_reply(State, Function, ?tMessageType_REPLY, {ReplySpec, ExceptionTuple}, Seqid)
end.
%%
%% Called when an exception has been explicitly thrown by the service, but it was
%% not one of the exceptions that was defined for the function.
%%
handle_unknown_exception(State, Function, Exception, Seqid) ->
handle_error(State, Function, {exception_not_declared_as_thrown,
Exception}, Seqid).
handle_error(State, Function, Error, Seqid) ->
Stack = erlang:get_stacktrace(),
error_logger:error_msg("~p had an error: ~p~n", [Function, {Error, Stack}]),
Message =
case application:get_env(thrift, exceptions_include_traces) of
{ok, true} ->
lists:flatten(io_lib:format("An error occurred: ~p~n",
[{Error, Stack}]));
_ ->
"An unknown handler error occurred."
end,
Reply = {?TApplicationException_Structure,
#'TApplicationException'{
message = Message,
type = ?TApplicationException_UNKNOWN}},
send_reply(State, Function, ?tMessageType_EXCEPTION, Reply, Seqid).
send_reply(State = #thrift_processor{protocol = Proto0}, Function, ReplyMessageType, Reply, Seqid) ->
try
{Proto1, ok} = thrift_protocol:write(Proto0, #protocol_message_begin{
name = atom_to_list(Function),
type = ReplyMessageType,
seqid = Seqid}),
{Proto2, ok} = thrift_protocol:write(Proto1, Reply),
{Proto3, ok} = thrift_protocol:write(Proto2, message_end),
{Proto4, ok} = thrift_protocol:flush_transport(Proto3),
{State#thrift_processor{protocol = Proto4}, ok}
catch
error:{badmatch, {_, {error, _} = Error}} ->
{State, Error}
end.

View file

@ -0,0 +1,413 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_protocol).
-export([new/2,
write/2,
read/2,
read/3,
skip/2,
flush_transport/1,
close_transport/1,
typeid_to_atom/1
]).
-export([behaviour_info/1]).
-include("thrift_constants.hrl").
-include("thrift_protocol.hrl").
-record(protocol, {module, data}).
behaviour_info(callbacks) ->
[
{read, 2},
{write, 2},
{flush_transport, 1},
{close_transport, 1}
];
behaviour_info(_Else) -> undefined.
new(Module, Data) when is_atom(Module) ->
{ok, #protocol{module = Module,
data = Data}}.
-spec flush_transport(#protocol{}) -> {#protocol{}, ok}.
flush_transport(Proto = #protocol{module = Module,
data = Data}) ->
{NewData, Result} = Module:flush_transport(Data),
{Proto#protocol{data = NewData}, Result}.
-spec close_transport(#protocol{}) -> ok.
close_transport(#protocol{module = Module,
data = Data}) ->
Module:close_transport(Data).
typeid_to_atom(?tType_STOP) -> field_stop;
typeid_to_atom(?tType_VOID) -> void;
typeid_to_atom(?tType_BOOL) -> bool;
typeid_to_atom(?tType_DOUBLE) -> double;
typeid_to_atom(?tType_I8) -> byte;
typeid_to_atom(?tType_I16) -> i16;
typeid_to_atom(?tType_I32) -> i32;
typeid_to_atom(?tType_I64) -> i64;
typeid_to_atom(?tType_STRING) -> string;
typeid_to_atom(?tType_STRUCT) -> struct;
typeid_to_atom(?tType_MAP) -> map;
typeid_to_atom(?tType_SET) -> set;
typeid_to_atom(?tType_LIST) -> list.
term_to_typeid(void) -> ?tType_VOID;
term_to_typeid(bool) -> ?tType_BOOL;
term_to_typeid(byte) -> ?tType_I8;
term_to_typeid(double) -> ?tType_DOUBLE;
term_to_typeid(i8) -> ?tType_I8;
term_to_typeid(i16) -> ?tType_I16;
term_to_typeid(i32) -> ?tType_I32;
term_to_typeid(i64) -> ?tType_I64;
term_to_typeid(string) -> ?tType_STRING;
term_to_typeid({struct, _}) -> ?tType_STRUCT;
term_to_typeid({map, _, _}) -> ?tType_MAP;
term_to_typeid({set, _}) -> ?tType_SET;
term_to_typeid({list, _}) -> ?tType_LIST.
%% Structure is like:
%% [{Fid, Type}, ...]
-spec read(#protocol{}, {struct, _StructDef}, atom()) -> {#protocol{}, {ok, tuple()}}.
read(IProto0, {struct, Structure}, Tag)
when is_list(Structure), is_atom(Tag) ->
% If we want a tagged tuple, we need to offset all the tuple indices
% by 1 to avoid overwriting the tag.
Offset = if Tag =/= undefined -> 1; true -> 0 end,
IndexList = case length(Structure) of
N when N > 0 -> lists:seq(1 + Offset, N + Offset);
_ -> []
end,
SWithIndices = [{Fid, {Type, Index}} ||
{{Fid, Type}, Index} <-
lists:zip(Structure, IndexList)],
% Fid -> {Type, Index}
SDict = dict:from_list(SWithIndices),
{IProto1, ok} = read(IProto0, struct_begin),
RTuple0 = erlang:make_tuple(length(Structure) + Offset, undefined),
RTuple1 = if Tag =/= undefined -> setelement(1, RTuple0, Tag);
true -> RTuple0
end,
{IProto2, RTuple2} = read_struct_loop(IProto1, SDict, RTuple1),
{IProto2, {ok, RTuple2}}.
%% NOTE: Keep this in sync with thrift_protocol_behaviour:read
-spec read
(#protocol{}, {struct, _Info}) -> {#protocol{}, {ok, tuple()} | {error, _Reason}};
(#protocol{}, tprot_cont_tag()) -> {#protocol{}, {ok, any()} | {error, _Reason}};
(#protocol{}, tprot_empty_tag()) -> {#protocol{}, ok | {error, _Reason}};
(#protocol{}, tprot_header_tag()) -> {#protocol{}, tprot_header_val() | {error, _Reason}};
(#protocol{}, tprot_data_tag()) -> {#protocol{}, {ok, any()} | {error, _Reason}}.
read(IProto, {struct, {Module, StructureName}}) when is_atom(Module),
is_atom(StructureName) ->
read(IProto, Module:struct_info(StructureName), StructureName);
read(IProto, S={struct, Structure}) when is_list(Structure) ->
read(IProto, S, undefined);
read(IProto0, {list, Type}) ->
{IProto1, #protocol_list_begin{etype = EType, size = Size}} =
read(IProto0, list_begin),
{EType, EType} = {term_to_typeid(Type), EType},
{List, IProto2} = lists:mapfoldl(fun(_, ProtoS0) ->
{ProtoS1, {ok, Item}} = read(ProtoS0, Type),
{Item, ProtoS1}
end,
IProto1,
lists:duplicate(Size, 0)),
{IProto3, ok} = read(IProto2, list_end),
{IProto3, {ok, List}};
read(IProto0, {map, KeyType, ValType}) ->
{IProto1, #protocol_map_begin{size = Size, ktype = KType, vtype = VType}} =
read(IProto0, map_begin),
_ = case Size of
0 -> 0;
_ ->
{KType, KType} = {term_to_typeid(KeyType), KType},
{VType, VType} = {term_to_typeid(ValType), VType}
end,
{List, IProto2} = lists:mapfoldl(fun(_, ProtoS0) ->
{ProtoS1, {ok, Key}} = read(ProtoS0, KeyType),
{ProtoS2, {ok, Val}} = read(ProtoS1, ValType),
{{Key, Val}, ProtoS2}
end,
IProto1,
lists:duplicate(Size, 0)),
{IProto3, ok} = read(IProto2, map_end),
{IProto3, {ok, dict:from_list(List)}};
read(IProto0, {set, Type}) ->
{IProto1, #protocol_set_begin{etype = EType, size = Size}} =
read(IProto0, set_begin),
{EType, EType} = {term_to_typeid(Type), EType},
{List, IProto2} = lists:mapfoldl(fun(_, ProtoS0) ->
{ProtoS1, {ok, Item}} = read(ProtoS0, Type),
{Item, ProtoS1}
end,
IProto1,
lists:duplicate(Size, 0)),
{IProto3, ok} = read(IProto2, set_end),
{IProto3, {ok, sets:from_list(List)}};
read(Protocol, ProtocolType) ->
read_specific(Protocol, ProtocolType).
%% NOTE: Keep this in sync with thrift_protocol_behaviour:read
-spec read_specific
(#protocol{}, tprot_empty_tag()) -> {#protocol{}, ok | {error, _Reason}};
(#protocol{}, tprot_header_tag()) -> {#protocol{}, tprot_header_val() | {error, _Reason}};
(#protocol{}, tprot_data_tag()) -> {#protocol{}, {ok, any()} | {error, _Reason}}.
read_specific(Proto = #protocol{module = Module,
data = ModuleData}, ProtocolType) ->
{NewData, Result} = Module:read(ModuleData, ProtocolType),
{Proto#protocol{data = NewData}, Result}.
read_struct_loop(IProto0, SDict, RTuple) ->
{IProto1, #protocol_field_begin{type = FType, id = Fid}} =
thrift_protocol:read(IProto0, field_begin),
case {FType, Fid} of
{?tType_STOP, _} ->
{IProto2, ok} = read(IProto1, struct_end),
{IProto2, RTuple};
_Else ->
case dict:find(Fid, SDict) of
{ok, {Type, Index}} ->
case term_to_typeid(Type) of
FType ->
{IProto2, {ok, Val}} = read(IProto1, Type),
{IProto3, ok} = thrift_protocol:read(IProto2, field_end),
NewRTuple = setelement(Index, RTuple, Val),
read_struct_loop(IProto3, SDict, NewRTuple);
Expected ->
error_logger:info_msg(
"Skipping field ~p with wrong type (~p != ~p)~n",
[Fid, FType, Expected]),
skip_field(FType, IProto1, SDict, RTuple)
end;
_Else2 ->
skip_field(FType, IProto1, SDict, RTuple)
end
end.
skip_field(FType, IProto0, SDict, RTuple) ->
FTypeAtom = thrift_protocol:typeid_to_atom(FType),
{IProto1, ok} = thrift_protocol:skip(IProto0, FTypeAtom),
{IProto2, ok} = read(IProto1, field_end),
read_struct_loop(IProto2, SDict, RTuple).
-spec skip(#protocol{}, any()) -> {#protocol{}, ok}.
skip(Proto0, struct) ->
{Proto1, ok} = read(Proto0, struct_begin),
{Proto2, ok} = skip_struct_loop(Proto1),
{Proto3, ok} = read(Proto2, struct_end),
{Proto3, ok};
skip(Proto0, map) ->
{Proto1, Map} = read(Proto0, map_begin),
{Proto2, ok} = skip_map_loop(Proto1, Map),
{Proto3, ok} = read(Proto2, map_end),
{Proto3, ok};
skip(Proto0, set) ->
{Proto1, Set} = read(Proto0, set_begin),
{Proto2, ok} = skip_set_loop(Proto1, Set),
{Proto3, ok} = read(Proto2, set_end),
{Proto3, ok};
skip(Proto0, list) ->
{Proto1, List} = read(Proto0, list_begin),
{Proto2, ok} = skip_list_loop(Proto1, List),
{Proto3, ok} = read(Proto2, list_end),
{Proto3, ok};
skip(Proto0, Type) when is_atom(Type) ->
{Proto1, _Ignore} = read(Proto0, Type),
{Proto1, ok}.
skip_struct_loop(Proto0) ->
{Proto1, #protocol_field_begin{type = Type}} = read(Proto0, field_begin),
case Type of
?tType_STOP ->
{Proto1, ok};
_Else ->
{Proto2, ok} = skip(Proto1, Type),
{Proto3, ok} = read(Proto2, field_end),
skip_struct_loop(Proto3)
end.
skip_map_loop(Proto0, Map = #protocol_map_begin{ktype = Ktype,
vtype = Vtype,
size = Size}) ->
case Size of
N when N > 0 ->
{Proto1, ok} = skip(Proto0, Ktype),
{Proto2, ok} = skip(Proto1, Vtype),
skip_map_loop(Proto2,
Map#protocol_map_begin{size = Size - 1});
0 -> {Proto0, ok}
end.
skip_set_loop(Proto0, Map = #protocol_set_begin{etype = Etype,
size = Size}) ->
case Size of
N when N > 0 ->
{Proto1, ok} = skip(Proto0, Etype),
skip_set_loop(Proto1,
Map#protocol_set_begin{size = Size - 1});
0 -> {Proto0, ok}
end.
skip_list_loop(Proto0, Map = #protocol_list_begin{etype = Etype,
size = Size}) ->
case Size of
N when N > 0 ->
{Proto1, ok} = skip(Proto0, Etype),
skip_list_loop(Proto1,
Map#protocol_list_begin{size = Size - 1});
0 -> {Proto0, ok}
end.
%%--------------------------------------------------------------------
%% Function: write(OProto, {Type, Data}) -> ok
%%
%% Type = {struct, StructDef} |
%% {list, Type} |
%% {map, KeyType, ValType} |
%% {set, Type} |
%% BaseType
%%
%% Data =
%% tuple() -- for struct
%% | list() -- for list
%% | dictionary() -- for map
%% | set() -- for set
%% | any() -- for base types
%%
%% Description:
%%--------------------------------------------------------------------
-spec write(#protocol{}, any()) -> {#protocol{}, ok | {error, _Reason}}.
write(Proto0, {{struct, StructDef}, Data})
when is_list(StructDef), is_tuple(Data), length(StructDef) == size(Data) - 1 ->
[StructName | Elems] = tuple_to_list(Data),
{Proto1, ok} = write(Proto0, #protocol_struct_begin{name = StructName}),
{Proto2, ok} = struct_write_loop(Proto1, StructDef, Elems),
{Proto3, ok} = write(Proto2, struct_end),
{Proto3, ok};
write(Proto, {{struct, {Module, StructureName}}, Data})
when is_atom(Module),
is_atom(StructureName),
element(1, Data) =:= StructureName ->
write(Proto, {Module:struct_info(StructureName), Data});
write(_, {{struct, {Module, StructureName}}, Data})
when is_atom(Module),
is_atom(StructureName) ->
erlang:error(struct_unmatched, {{provided, element(1, Data)},
{expected, StructureName}});
write(Proto0, {{list, Type}, Data})
when is_list(Data) ->
{Proto1, ok} = write(Proto0,
#protocol_list_begin{
etype = term_to_typeid(Type),
size = length(Data)
}),
Proto2 = lists:foldl(fun(Elem, ProtoIn) ->
{ProtoOut, ok} = write(ProtoIn, {Type, Elem}),
ProtoOut
end,
Proto1,
Data),
{Proto3, ok} = write(Proto2, list_end),
{Proto3, ok};
write(Proto0, {{map, KeyType, ValType}, Data}) ->
{Proto1, ok} = write(Proto0,
#protocol_map_begin{
ktype = term_to_typeid(KeyType),
vtype = term_to_typeid(ValType),
size = dict:size(Data)
}),
Proto2 = dict:fold(fun(KeyData, ValData, ProtoS0) ->
{ProtoS1, ok} = write(ProtoS0, {KeyType, KeyData}),
{ProtoS2, ok} = write(ProtoS1, {ValType, ValData}),
ProtoS2
end,
Proto1,
Data),
{Proto3, ok} = write(Proto2, map_end),
{Proto3, ok};
write(Proto0, {{set, Type}, Data}) ->
true = sets:is_set(Data),
{Proto1, ok} = write(Proto0,
#protocol_set_begin{
etype = term_to_typeid(Type),
size = sets:size(Data)
}),
Proto2 = sets:fold(fun(Elem, ProtoIn) ->
{ProtoOut, ok} = write(ProtoIn, {Type, Elem}),
ProtoOut
end,
Proto1,
Data),
{Proto3, ok} = write(Proto2, set_end),
{Proto3, ok};
write(Proto = #protocol{module = Module,
data = ModuleData}, Data) ->
{NewData, Result} = Module:write(ModuleData, Data),
{Proto#protocol{data = NewData}, Result}.
struct_write_loop(Proto0, [{Fid, Type} | RestStructDef], [Data | RestData]) ->
NewProto = case Data of
undefined ->
Proto0; % null fields are skipped in response
_ ->
{Proto1, ok} = write(Proto0,
#protocol_field_begin{
type = term_to_typeid(Type),
id = Fid
}),
{Proto2, ok} = write(Proto1, {Type, Data}),
{Proto3, ok} = write(Proto2, field_end),
Proto3
end,
struct_write_loop(NewProto, RestStructDef, RestData);
struct_write_loop(Proto, [], []) ->
write(Proto, field_stop).

View file

@ -0,0 +1,258 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_reconnecting_client).
-behaviour(gen_server).
%% API
-export([ call/3,
get_stats/1,
get_and_reset_stats/1 ]).
-export([ start_link/6 ]).
%% gen_server callbacks
-export([ init/1,
handle_call/3,
handle_cast/2,
handle_info/2,
terminate/2,
code_change/3 ]).
-record( state, { client = nil,
host,
port,
thrift_svc,
thrift_opts,
reconn_min,
reconn_max,
reconn_time = 0,
op_cnt_dict,
op_time_dict } ).
%%====================================================================
%% API
%%====================================================================
%%--------------------------------------------------------------------
%% Function: start_link() -> {ok,Pid} | ignore | {error,Error}
%% Description: Starts the server
%%--------------------------------------------------------------------
start_link( Host, Port,
ThriftSvc, ThriftOpts,
ReconnMin, ReconnMax ) ->
gen_server:start_link( ?MODULE,
[ Host, Port,
ThriftSvc, ThriftOpts,
ReconnMin, ReconnMax ],
[] ).
call( Pid, Op, Args ) ->
gen_server:call( Pid, { call, Op, Args } ).
get_stats( Pid ) ->
gen_server:call( Pid, get_stats ).
get_and_reset_stats( Pid ) ->
gen_server:call( Pid, get_and_reset_stats ).
%%====================================================================
%% gen_server callbacks
%%====================================================================
%%--------------------------------------------------------------------
%% Function: init(Args) -> {ok, State} |
%% {ok, State, Timeout} |
%% ignore |
%% {stop, Reason}
%% Description: Start the server.
%%--------------------------------------------------------------------
init( [ Host, Port, TSvc, TOpts, ReconnMin, ReconnMax ] ) ->
process_flag( trap_exit, true ),
State = #state{ host = Host,
port = Port,
thrift_svc = TSvc,
thrift_opts = TOpts,
reconn_min = ReconnMin,
reconn_max = ReconnMax,
op_cnt_dict = dict:new(),
op_time_dict = dict:new() },
{ ok, try_connect( State ) }.
%%--------------------------------------------------------------------
%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} |
%% {reply, Reply, State, Timeout} |
%% {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, Reply, State} |
%% {stop, Reason, State}
%% Description: Handling call messages
%%--------------------------------------------------------------------
handle_call( { call, Op, _ },
_From,
State = #state{ client = nil } ) ->
{ reply, { error, noconn }, incr_stats( Op, "failfast", 1, State ) };
handle_call( { call, Op, Args },
_From,
State=#state{ client = Client } ) ->
Timer = timer_fun(),
Result = ( catch thrift_client:call( Client, Op, Args) ),
Time = Timer(),
case Result of
{ C, { ok, Reply } } ->
S = incr_stats( Op, "success", Time, State#state{ client = C } ),
{ reply, {ok, Reply }, S };
{ _, { E, Msg } } when E == error; E == exception ->
S = incr_stats( Op, "error", Time, try_connect( State ) ),
{ reply, { E, Msg }, S };
Other ->
S = incr_stats( Op, "error", Time, try_connect( State ) ),
{ reply, Other, S }
end;
handle_call( get_stats,
_From,
State = #state{} ) ->
{ reply, stats( State ), State };
handle_call( get_and_reset_stats,
_From,
State = #state{} ) ->
{ reply, stats( State ), reset_stats( State ) }.
%%--------------------------------------------------------------------
%% Function: handle_cast(Msg, State) -> {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, State}
%% Description: Handling cast messages
%%--------------------------------------------------------------------
handle_cast( _Msg, State ) ->
{ noreply, State }.
%%--------------------------------------------------------------------
%% Function: handle_info(Info, State) -> {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, State}
%% Description: Handling all non call/cast messages
%%--------------------------------------------------------------------
handle_info( try_connect, State ) ->
{ noreply, try_connect( State ) };
handle_info( _Info, State ) ->
{ noreply, State }.
%%--------------------------------------------------------------------
%% Function: terminate(Reason, State) -> void()
%% Description: This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any necessary
%% cleaning up. When it returns, the gen_server terminates with Reason.
%% The return value is ignored.
%%--------------------------------------------------------------------
terminate( _Reason, #state{ client = Client } ) ->
thrift_client:close( Client ),
ok.
%%--------------------------------------------------------------------
%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState}
%% Description: Convert process state when code is changed
%%--------------------------------------------------------------------
code_change( _OldVsn, State, _Extra ) ->
{ ok, State }.
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
try_connect( State = #state{ client = OldClient,
host = Host,
port = Port,
thrift_svc = TSvc,
thrift_opts = TOpts } ) ->
case OldClient of
nil -> ok;
_ -> ( catch thrift_client:close( OldClient ) )
end,
case catch thrift_client_util:new( Host, Port, TSvc, TOpts ) of
{ ok, Client } ->
State#state{ client = Client, reconn_time = 0 };
{ E, Msg } when E == error; E == exception ->
ReconnTime = reconn_time( State ),
error_logger:error_msg( "[~w] ~w connect failed (~w), trying again in ~w ms~n",
[ self(), TSvc, Msg, ReconnTime ] ),
erlang:send_after( ReconnTime, self(), try_connect ),
State#state{ client = nil, reconn_time = ReconnTime }
end.
reconn_time( #state{ reconn_min = ReconnMin, reconn_time = 0 } ) ->
ReconnMin;
reconn_time( #state{ reconn_max = ReconnMax, reconn_time = ReconnMax } ) ->
ReconnMax;
reconn_time( #state{ reconn_max = ReconnMax, reconn_time = R } ) ->
Backoff = 2 * R,
case Backoff > ReconnMax of
true -> ReconnMax;
false -> Backoff
end.
-ifdef(time_correction).
timer_fun() ->
T1 = erlang:monotonic_time(),
fun() ->
T2 = erlang:monotonic_time(),
erlang:convert_time_unit(T2 - T1, native, micro_seconds)
end.
-else.
timer_fun() ->
T1 = erlang:now(),
fun() ->
T2 = erlang:now(),
timer:now_diff(T2, T1)
end.
-endif.
incr_stats( Op, Result, Time,
State = #state{ op_cnt_dict = OpCntDict,
op_time_dict = OpTimeDict } ) ->
Key = lists:flatten( [ atom_to_list( Op ), [ "_" | Result ] ] ),
State#state{ op_cnt_dict = dict:update_counter( Key, 1, OpCntDict ),
op_time_dict = dict:update_counter( Key, Time, OpTimeDict ) }.
stats( #state{ thrift_svc = TSvc,
op_cnt_dict = OpCntDict,
op_time_dict = OpTimeDict } ) ->
Svc = atom_to_list(TSvc),
F = fun( Key, Count, Stats ) ->
Name = lists:flatten( [ Svc, [ "_" | Key ] ] ),
Micros = dict:fetch( Key, OpTimeDict ),
[ { Name, Count, Micros } | Stats ]
end,
dict:fold( F, [], OpCntDict ).
reset_stats( State = #state{} ) ->
State#state{ op_cnt_dict = dict:new(), op_time_dict = dict:new() }.

View file

@ -0,0 +1,183 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_server).
-behaviour(gen_server).
%% API
-export([start_link/3, stop/1, take_socket/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-define(SERVER, ?MODULE).
-record(state, {listen_socket, acceptor_ref, service, handler}).
%%====================================================================
%% API
%%====================================================================
%%--------------------------------------------------------------------
%% Function: start_link() -> {ok,Pid} | ignore | {error,Error}
%% Description: Starts the server
%%--------------------------------------------------------------------
start_link(Port, Service, HandlerModule) when is_integer(Port), is_atom(HandlerModule) ->
gen_server:start_link({local, ?SERVER}, ?MODULE, {Port, Service, HandlerModule}, []).
%%--------------------------------------------------------------------
%% Function: stop(Pid) -> ok, {error, Reason}
%% Description: Stops the server.
%%--------------------------------------------------------------------
stop(Pid) when is_pid(Pid) ->
gen_server:call(Pid, stop).
take_socket(Server, Socket) ->
gen_server:call(Server, {take_socket, Socket}).
%%====================================================================
%% gen_server callbacks
%%====================================================================
%%--------------------------------------------------------------------
%% Function: init(Args) -> {ok, State} |
%% {ok, State, Timeout} |
%% ignore |
%% {stop, Reason}
%% Description: Initiates the server
%%--------------------------------------------------------------------
init({Port, Service, Handler}) ->
{ok, Socket} = gen_tcp:listen(Port,
[binary,
{packet, 0},
{active, false},
{nodelay, true},
{reuseaddr, true}]),
{ok, Ref} = prim_inet:async_accept(Socket, -1),
{ok, #state{listen_socket = Socket,
acceptor_ref = Ref,
service = Service,
handler = Handler}}.
%%--------------------------------------------------------------------
%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} |
%% {reply, Reply, State, Timeout} |
%% {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, Reply, State} |
%% {stop, Reason, State}
%% Description: Handling call messages
%%--------------------------------------------------------------------
handle_call(stop, _From, State) ->
{stop, stopped, ok, State};
handle_call({take_socket, Socket}, {FromPid, _Tag}, State) ->
Result = gen_tcp:controlling_process(Socket, FromPid),
{reply, Result, State}.
%%--------------------------------------------------------------------
%% Function: handle_cast(Msg, State) -> {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, State}
%% Description: Handling cast messages
%%--------------------------------------------------------------------
handle_cast(_Msg, State) ->
{noreply, State}.
%%--------------------------------------------------------------------
%% Function: handle_info(Info, State) -> {noreply, State} |
%% {noreply, State, Timeout} |
%% {stop, Reason, State}
%% Description: Handling all non call/cast messages
%%--------------------------------------------------------------------
handle_info({inet_async, ListenSocket, Ref, {ok, ClientSocket}},
State = #state{listen_socket = ListenSocket,
acceptor_ref = Ref,
service = Service,
handler = Handler}) ->
case set_sockopt(ListenSocket, ClientSocket) of
ok ->
%% New client connected - start processor
start_processor(ClientSocket, Service, Handler),
{ok, NewRef} = prim_inet:async_accept(ListenSocket, -1),
{noreply, State#state{acceptor_ref = NewRef}};
{error, Reason} ->
error_logger:error_msg("Couldn't set socket opts: ~p~n",
[Reason]),
{stop, Reason, State}
end;
handle_info({inet_async, _ListenSocket, _Ref, Error}, State) ->
error_logger:error_msg("Error in acceptor: ~p~n", [Error]),
{stop, Error, State};
handle_info(_Info, State) ->
{noreply, State}.
%%--------------------------------------------------------------------
%% Function: terminate(Reason, State) -> void()
%% Description: This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any necessary
%% cleaning up. When it returns, the gen_server terminates with Reason.
%% The return value is ignored.
%%--------------------------------------------------------------------
terminate(_Reason, _State) ->
ok.
%%--------------------------------------------------------------------
%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState}
%% Description: Convert process state when code is changed
%%--------------------------------------------------------------------
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
set_sockopt(ListenSocket, ClientSocket) ->
true = inet_db:register_socket(ClientSocket, inet_tcp),
case prim_inet:getopts(ListenSocket,
[active, nodelay, keepalive, delay_send, priority, tos]) of
{ok, Opts} ->
case prim_inet:setopts(ClientSocket, Opts) of
ok -> ok;
Error -> gen_tcp:close(ClientSocket),
Error
end;
Error ->
gen_tcp:close(ClientSocket),
Error
end.
start_processor(Socket, Service, Handler) ->
Server = self(),
ProtoGen = fun() ->
% Become the controlling process
ok = take_socket(Server, Socket),
{ok, SocketTransport} = thrift_socket_transport:new(Socket),
{ok, BufferedTransport} = thrift_buffered_transport:new(SocketTransport),
{ok, Protocol} = thrift_binary_protocol:new(BufferedTransport),
{ok, Protocol}
end,
spawn(thrift_processor, init, [{Server, ProtoGen, Service, Handler}]).

View file

@ -0,0 +1,25 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_service).
-export([behaviour_info/1]).
behaviour_info(callbacks) ->
[{function_info, 2}].

View file

@ -0,0 +1,320 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_socket_server).
-behaviour(gen_server).
-include ("thrift_constants.hrl").
-ifdef(TEST).
-compile(export_all).
-export_records([thrift_socket_server]).
-else.
-export([start/1, stop/1]).
-export([init/1, handle_call/3, handle_cast/2, terminate/2, code_change/3,
handle_info/2]).
-export([acceptor_loop/1]).
-endif.
-record(thrift_socket_server,
{port,
service,
handler,
name,
max=2048,
ip=any,
listen=null,
acceptor=null,
socket_opts=[{recv_timeout, 500}],
protocol=binary,
framed=false,
ssltransport=false,
ssloptions=[]
}).
start(State=#thrift_socket_server{}) ->
start_server(State);
start(Options) ->
start(parse_options(Options)).
stop(Name) when is_atom(Name) ->
gen_server:cast(Name, stop);
stop(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, stop);
stop({local, Name}) ->
stop(Name);
stop({global, Name}) ->
stop(Name);
stop(Options) ->
State = parse_options(Options),
stop(State#thrift_socket_server.name).
%% Internal API
parse_options(Options) ->
parse_options(Options, #thrift_socket_server{}).
parse_options([], State) ->
State;
parse_options([{name, L} | Rest], State) when is_list(L) ->
Name = {local, list_to_atom(L)},
parse_options(Rest, State#thrift_socket_server{name=Name});
parse_options([{name, A} | Rest], State) when is_atom(A) ->
Name = {local, A},
parse_options(Rest, State#thrift_socket_server{name=Name});
parse_options([{name, Name} | Rest], State) ->
parse_options(Rest, State#thrift_socket_server{name=Name});
parse_options([{port, L} | Rest], State) when is_list(L) ->
Port = list_to_integer(L),
parse_options(Rest, State#thrift_socket_server{port=Port});
parse_options([{port, Port} | Rest], State) ->
parse_options(Rest, State#thrift_socket_server{port=Port});
parse_options([{ip, Ip} | Rest], State) ->
ParsedIp = case Ip of
any ->
any;
Ip when is_tuple(Ip) ->
Ip;
Ip when is_list(Ip) ->
{ok, IpTuple} = inet_parse:address(Ip),
IpTuple
end,
parse_options(Rest, State#thrift_socket_server{ip=ParsedIp});
parse_options([{socket_opts, L} | Rest], State) when is_list(L), length(L) > 0 ->
parse_options(Rest, State#thrift_socket_server{socket_opts=L});
parse_options([{handler, []} | _Rest], _State) ->
throw("At least an error handler must be defined.");
parse_options([{handler, ServiceHandlerPropertyList} | Rest], State) when is_list(ServiceHandlerPropertyList) ->
ServiceHandlerMap =
case State#thrift_socket_server.handler of
undefined ->
lists:foldl(
fun ({ServiceName, ServiceHandler}, Acc) when is_list(ServiceName), is_atom(ServiceHandler) ->
thrift_multiplexed_map_wrapper:store(ServiceName, ServiceHandler, Acc);
(_, _Acc) ->
throw("The handler option is not properly configured for multiplexed services. It should be a kind of [{\"error_handler\", Module::atom()}, {SericeName::list(), Module::atom()}, ...]")
end, thrift_multiplexed_map_wrapper:new(), ServiceHandlerPropertyList);
_ -> throw("Error while parsing the handler option.")
end,
case thrift_multiplexed_map_wrapper:find(?MULTIPLEXED_ERROR_HANDLER_KEY, ServiceHandlerMap) of
{ok, _ErrorHandler} -> parse_options(Rest, State#thrift_socket_server{handler=ServiceHandlerMap});
error -> throw("The handler option is not properly configured for multiplexed services. It should be a kind of [{\"error_handler\", Module::atom()}, {SericeName::list(), Module::atom()}, ...]")
end;
parse_options([{handler, Handler} | Rest], State) when State#thrift_socket_server.handler == undefined, is_atom(Handler) ->
parse_options(Rest, State#thrift_socket_server{handler=Handler});
parse_options([{service, []} | _Rest], _State) ->
throw("At least one service module must be defined.");
parse_options([{service, ServiceModulePropertyList} | Rest], State) when is_list(ServiceModulePropertyList) ->
ServiceModuleMap =
case State#thrift_socket_server.service of
undefined ->
lists:foldl(
fun ({ServiceName, ServiceModule}, Acc) when is_list(ServiceName), is_atom(ServiceModule) ->
thrift_multiplexed_map_wrapper:store(ServiceName, ServiceModule, Acc);
(_, _Acc) ->
throw("The service option is not properly configured for multiplexed services. It should be a kind of [{SericeName::list(), ServiceModule::atom()}, ...]")
end, thrift_multiplexed_map_wrapper:new(), ServiceModulePropertyList);
_ -> throw("Error while parsing the service option.")
end,
parse_options(Rest, State#thrift_socket_server{service=ServiceModuleMap});
parse_options([{service, Service} | Rest], State) when State#thrift_socket_server.service == undefined, is_atom(Service) ->
parse_options(Rest, State#thrift_socket_server{service=Service});
parse_options([{max, Max} | Rest], State) ->
MaxInt = case Max of
Max when is_list(Max) ->
list_to_integer(Max);
Max when is_integer(Max) ->
Max
end,
parse_options(Rest, State#thrift_socket_server{max=MaxInt});
parse_options([{protocol, Proto} | Rest], State) when is_atom(Proto) ->
parse_options(Rest, State#thrift_socket_server{protocol=Proto});
parse_options([{framed, Framed} | Rest], State) when is_boolean(Framed) ->
parse_options(Rest, State#thrift_socket_server{framed=Framed});
parse_options([{ssltransport, SSLTransport} | Rest], State) when is_boolean(SSLTransport) ->
parse_options(Rest, State#thrift_socket_server{ssltransport=SSLTransport});
parse_options([{ssloptions, SSLOptions} | Rest], State) when is_list(SSLOptions) ->
parse_options(Rest, State#thrift_socket_server{ssloptions=SSLOptions}).
start_server(State=#thrift_socket_server{name=Name}) ->
case Name of
undefined ->
gen_server:start_link(?MODULE, State, []);
_ ->
gen_server:start_link(Name, ?MODULE, State, [])
end.
init(State=#thrift_socket_server{ip=Ip, port=Port}) ->
process_flag(trap_exit, true),
BaseOpts = [binary,
{reuseaddr, true},
{packet, 0},
{backlog, 4096},
{recbuf, 8192},
{active, false}],
Opts = case Ip of
any ->
BaseOpts;
Ip ->
[{ip, Ip} | BaseOpts]
end,
case gen_tcp_listen(Port, Opts, State) of
{stop, eacces} ->
%% fdsrv module allows another shot to bind
%% ports which require root access
case Port < 1024 of
true ->
case fdsrv:start() of
{ok, _} ->
case fdsrv:bind_socket(tcp, Port) of
{ok, Fd} ->
gen_tcp_listen(Port, [{fd, Fd} | Opts], State);
_ ->
{stop, fdsrv_bind_failed}
end;
_ ->
{stop, fdsrv_start_failed}
end;
false ->
{stop, eacces}
end;
Other ->
error_logger:info_msg("thrift service listening on port ~p", [Port]),
Other
end.
gen_tcp_listen(Port, Opts, State) ->
case gen_tcp:listen(Port, Opts) of
{ok, Listen} ->
{ok, ListenPort} = inet:port(Listen),
{ok, new_acceptor(State#thrift_socket_server{listen=Listen,
port=ListenPort})};
{error, Reason} ->
{stop, Reason}
end.
new_acceptor(State=#thrift_socket_server{max=0}) ->
error_logger:error_msg("Not accepting new connections"),
State#thrift_socket_server{acceptor=null};
new_acceptor(State=#thrift_socket_server{listen=Listen,
service=Service, handler=Handler,
socket_opts=Opts, framed=Framed, protocol=Proto,
ssltransport=SslTransport, ssloptions=SslOptions
}) ->
Pid = proc_lib:spawn_link(?MODULE, acceptor_loop,
[{self(), Listen, Service, Handler, Opts, Framed, SslTransport, SslOptions, Proto}]),
State#thrift_socket_server{acceptor=Pid}.
acceptor_loop({Server, Listen, Service, Handler, SocketOpts, Framed, SslTransport, SslOptions, Proto})
when is_pid(Server), is_list(SocketOpts) ->
case catch gen_tcp:accept(Listen) of % infinite timeout
{ok, Socket} ->
gen_server:cast(Server, {accepted, self()}),
ProtoGen = fun() ->
{ok, SocketTransport} = case SslTransport of
true -> thrift_sslsocket_transport:new(Socket, SocketOpts, SslOptions);
false -> thrift_socket_transport:new(Socket, SocketOpts)
end,
{ok, Transport} = case Framed of
true -> thrift_framed_transport:new(SocketTransport);
false -> thrift_buffered_transport:new(SocketTransport)
end,
{ok, Protocol} = case Proto of
compact -> thrift_compact_protocol:new(Transport);
json -> thrift_json_protocol:new(Transport);
_ -> thrift_binary_protocol:new(Transport)
end,
{ok, Protocol}
end,
thrift_processor:init({Server, ProtoGen, Service, Handler});
{error, closed} ->
exit({error, closed});
Other ->
error_logger:error_report(
[{application, thrift},
"Accept failed error",
lists:flatten(io_lib:format("~p", [Other]))]),
exit({error, accept_failed})
end.
handle_call({get, port}, _From, State=#thrift_socket_server{port=Port}) ->
{reply, Port, State};
handle_call(_Message, _From, State) ->
Res = error,
{reply, Res, State}.
handle_cast({accepted, Pid},
State=#thrift_socket_server{acceptor=Pid, max=Max}) ->
% io:format("accepted ~p~n", [Pid]),
State1 = State#thrift_socket_server{max=Max - 1},
{noreply, new_acceptor(State1)};
handle_cast(stop, State) ->
{stop, normal, State}.
terminate(Reason, #thrift_socket_server{listen=Listen, port=Port}) ->
gen_tcp:close(Listen),
{backtrace, Bt} = erlang:process_info(self(), backtrace),
error_logger:error_report({?MODULE, ?LINE,
{child_error, Reason, Bt}}),
case Port < 1024 of
true ->
catch fdsrv:stop(),
ok;
false ->
ok
end.
code_change(_OldVsn, State, _Extra) ->
State.
handle_info({'EXIT', Pid, normal},
State=#thrift_socket_server{acceptor=Pid}) ->
{noreply, new_acceptor(State)};
handle_info({'EXIT', Pid, Reason},
State=#thrift_socket_server{acceptor=Pid}) ->
error_logger:error_report({?MODULE, ?LINE,
{acceptor_error, Reason}}),
timer:sleep(100),
{noreply, new_acceptor(State)};
handle_info({'EXIT', _LoopPid, Reason},
State=#thrift_socket_server{acceptor=Pid, max=Max}) ->
case Reason of
normal -> ok;
shutdown -> ok;
_ -> error_logger:error_report({?MODULE, ?LINE,
{child_error, Reason, erlang:get_stacktrace()}})
end,
State1 = State#thrift_socket_server{max=Max + 1},
State2 = case Pid of
null -> new_acceptor(State1);
_ -> State1
end,
{noreply, State2};
handle_info(Info, State) ->
error_logger:info_report([{'INFO', Info}, {'State', State}]),
{noreply, State}.

View file

@ -0,0 +1,176 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_socket_transport).
-behaviour(thrift_transport).
%% constructors
-export([new/1, new/2]).
%% transport callbacks
-export([read/2, read_exact/2, write/2, flush/1, close/1]).
%% legacy api
-export([new_transport_factory/3]).
-record(t_socket, {
socket,
recv_timeout=60000,
buffer = []
}).
-type state() :: #t_socket{}.
-spec new(Socket::any()) ->
thrift_transport:t_transport().
new(Socket) -> new(Socket, []).
-spec new(Socket::any(), Opts::list()) ->
thrift_transport:t_transport().
new(Socket, Opts) when is_list(Opts) ->
State = parse_opts(Opts, #t_socket{socket = Socket}),
thrift_transport:new(?MODULE, State).
parse_opts([{recv_timeout, Timeout}|Rest], State)
when is_integer(Timeout), Timeout > 0 ->
parse_opts(Rest, State#t_socket{recv_timeout = Timeout});
parse_opts([{recv_timeout, infinity}|Rest], State) ->
parse_opts(Rest, State#t_socket{recv_timeout = infinity});
parse_opts([], State) ->
State.
-include("thrift_transport_behaviour.hrl").
read(State = #t_socket{buffer = Buf}, Len)
when is_integer(Len), Len >= 0 ->
Binary = iolist_to_binary(Buf),
case iolist_size(Binary) of
X when X >= Len ->
{Result, Remaining} = split_binary(Binary, Len),
{State#t_socket{buffer = Remaining}, {ok, Result}};
_ -> recv(State, Len)
end.
recv(State = #t_socket{socket = Socket, buffer = Buf}, Len) ->
case gen_tcp:recv(Socket, 0, State#t_socket.recv_timeout) of
{error, Error} ->
gen_tcp:close(Socket),
{State, {error, Error}};
{ok, Data} ->
Binary = iolist_to_binary([Buf, Data]),
Give = min(iolist_size(Binary), Len),
{Result, Remaining} = split_binary(Binary, Give),
{State#t_socket{buffer = Remaining}, {ok, Result}}
end.
read_exact(State = #t_socket{buffer = Buf}, Len)
when is_integer(Len), Len >= 0 ->
Binary = iolist_to_binary(Buf),
case iolist_size(Binary) of
X when X >= Len -> read(State, Len);
X ->
case gen_tcp:recv(State#t_socket.socket, Len - X, State#t_socket.recv_timeout) of
{error, Error} ->
gen_tcp:close(State#t_socket.socket),
{State, {error, Error}};
{ok, Data} ->
{State#t_socket{buffer = []}, {ok, <<Binary/binary, Data/binary>>}}
end
end.
write(State = #t_socket{socket = Socket}, Data) ->
case gen_tcp:send(Socket, Data) of
{error, Error} ->
gen_tcp:close(Socket),
{State, {error, Error}};
ok -> {State, ok}
end.
flush(State) ->
{State#t_socket{buffer = []}, ok}.
close(State = #t_socket{socket = Socket}) ->
{State, gen_tcp:close(Socket)}.
%% legacy api. left for compatibility
%% The following "local" record is filled in by parse_factory_options/2
%% below. These options can be passed to new_protocol_factory/3 in a
%% proplists-style option list. They're parsed like this so it is an O(n)
%% operation instead of O(n^2)
-record(factory_opts, {
connect_timeout = infinity,
sockopts = [],
framed = false
}).
parse_factory_options([], FactoryOpts, TransOpts) -> {FactoryOpts, TransOpts};
parse_factory_options([{framed, Bool}|Rest], FactoryOpts, TransOpts)
when is_boolean(Bool) ->
parse_factory_options(Rest, FactoryOpts#factory_opts{framed = Bool}, TransOpts);
parse_factory_options([{sockopts, OptList}|Rest], FactoryOpts, TransOpts)
when is_list(OptList) ->
parse_factory_options(Rest, FactoryOpts#factory_opts{sockopts = OptList}, TransOpts);
parse_factory_options([{connect_timeout, TO}|Rest], FactoryOpts, TransOpts)
when TO =:= infinity; is_integer(TO) ->
parse_factory_options(Rest, FactoryOpts#factory_opts{connect_timeout = TO}, TransOpts);
parse_factory_options([{recv_timeout, TO}|Rest], FactoryOpts, TransOpts)
when TO =:= infinity; is_integer(TO) ->
parse_factory_options(Rest, FactoryOpts, [{recv_timeout, TO}] ++ TransOpts).
%% Generates a "transport factory" function - a fun which returns a thrift_transport()
%% instance.
%% State can be passed into a protocol factory to generate a connection to a
%% thrift server over a socket.
new_transport_factory(Host, Port, Options) ->
{FactoryOpts, TransOpts} = parse_factory_options(Options, #factory_opts{}, []),
{ok, fun() -> SockOpts = [binary,
{packet, 0},
{active, false},
{nodelay, true}|FactoryOpts#factory_opts.sockopts
],
case catch gen_tcp:connect(
Host,
Port,
SockOpts,
FactoryOpts#factory_opts.connect_timeout
) of
{ok, Sock} ->
{ok, Transport} = thrift_socket_transport:new(Sock, TransOpts),
{ok, BufTransport} = case FactoryOpts#factory_opts.framed of
true -> thrift_framed_transport:new(Transport);
false -> thrift_buffered_transport:new(Transport)
end,
{ok, BufTransport};
Error -> Error
end
end}.

View file

@ -0,0 +1,147 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_sslsocket_transport).
-include("thrift_transport_behaviour.hrl").
-behaviour(thrift_transport).
-export([new/3,
write/2, read/2, flush/1, close/1,
new_transport_factory/3]).
%% Export only for the transport factory
-export([new/2]).
-record(data, {socket,
recv_timeout=infinity}).
-type state() :: #data{}.
%% The following "local" record is filled in by parse_factory_options/2
%% below. These options can be passed to new_protocol_factory/3 in a
%% proplists-style option list. They're parsed like this so it is an O(n)
%% operation instead of O(n^2)
-record(factory_opts, {connect_timeout = infinity,
sockopts = [],
framed = false,
ssloptions = []}).
parse_factory_options([], Opts) ->
Opts;
parse_factory_options([{framed, Bool} | Rest], Opts) when is_boolean(Bool) ->
parse_factory_options(Rest, Opts#factory_opts{framed=Bool});
parse_factory_options([{sockopts, OptList} | Rest], Opts) when is_list(OptList) ->
parse_factory_options(Rest, Opts#factory_opts{sockopts=OptList});
parse_factory_options([{connect_timeout, TO} | Rest], Opts) when TO =:= infinity; is_integer(TO) ->
parse_factory_options(Rest, Opts#factory_opts{connect_timeout=TO});
parse_factory_options([{ssloptions, SslOptions} | Rest], Opts) when is_list(SslOptions) ->
parse_factory_options(Rest, Opts#factory_opts{ssloptions=SslOptions}).
new(Socket, SockOpts, SslOptions) when is_list(SockOpts), is_list(SslOptions) ->
inet:setopts(Socket, [{active, false}]), %% => prevent the ssl handshake messages get lost
%% upgrade to an ssl socket
case catch ssl:ssl_accept(Socket, SslOptions) of % infinite timeout
{ok, SslSocket} ->
new(SslSocket, SockOpts);
{error, Reason} ->
exit({error, Reason});
Other ->
error_logger:error_report(
[{application, thrift},
"SSL accept failed error",
lists:flatten(io_lib:format("~p", [Other]))]),
exit({error, ssl_accept_failed})
end.
new(SslSocket, SockOpts) ->
State =
case lists:keysearch(recv_timeout, 1, SockOpts) of
{value, {recv_timeout, Timeout}}
when is_integer(Timeout), Timeout > 0 ->
#data{socket=SslSocket, recv_timeout=Timeout};
_ ->
#data{socket=SslSocket}
end,
thrift_transport:new(?MODULE, State).
%% Data :: iolist()
write(This = #data{socket = Socket}, Data) ->
{This, ssl:send(Socket, Data)}.
read(This = #data{socket=Socket, recv_timeout=Timeout}, Len)
when is_integer(Len), Len >= 0 ->
case ssl:recv(Socket, Len, Timeout) of
Err = {error, timeout} ->
error_logger:info_msg("read timeout: peer conn ~p", [inet:peername(Socket)]),
ssl:close(Socket),
{This, Err};
Data ->
{This, Data}
end.
%% We can't really flush - everything is flushed when we write
flush(This) ->
{This, ok}.
close(This = #data{socket = Socket}) ->
{This, ssl:close(Socket)}.
%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%
%% Generates a "transport factory" function - a fun which returns a thrift_transport()
%% instance.
%% This can be passed into a protocol factory to generate a connection to a
%% thrift server over a socket.
%%
new_transport_factory(Host, Port, Options) ->
ParsedOpts = parse_factory_options(Options, #factory_opts{}),
F = fun() ->
SockOpts = [binary,
{packet, 0},
{active, false},
{nodelay, true} |
ParsedOpts#factory_opts.sockopts],
case catch gen_tcp:connect(Host, Port, SockOpts,
ParsedOpts#factory_opts.connect_timeout) of
{ok, Sock} ->
SslSock = case catch ssl:connect(Sock, ParsedOpts#factory_opts.ssloptions,
ParsedOpts#factory_opts.connect_timeout) of
{ok, SslSocket} ->
SslSocket;
Other ->
error_logger:info_msg("error while connecting over ssl - reason: ~p~n", [Other]),
catch gen_tcp:close(Sock),
exit(error)
end,
{ok, Transport} = thrift_sslsocket_transport:new(SslSock, SockOpts),
{ok, BufTransport} =
case ParsedOpts#factory_opts.framed of
true -> thrift_framed_transport:new(Transport);
false -> thrift_buffered_transport:new(Transport)
end,
{ok, BufTransport};
Error ->
Error
end
end,
{ok, F}.

View file

@ -0,0 +1,128 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_transport).
-export([behaviour_info/1]).
%% constructors
-export([new/1, new/2]).
%% transport callbacks
-export([read/2, read_exact/2, write/2, flush/1, close/1]).
-export_type([t_transport/0]).
behaviour_info(callbacks) ->
[{read, 2}, {write, 2}, {flush, 1}, {close, 1}].
-record(t_transport, {
module,
state
}).
-type state() :: #t_transport{}.
-type t_transport() :: #t_transport{}.
-ifdef(transport_wrapper_module).
-define(debug_wrap(Transport),
case Transport#t_transport.module of
?transport_wrapper_module -> Transport;
_Else ->
{ok, Result} = ?transport_wrapper_module:new(Transport),
Result
end
).
-else.
-define(debug_wrap(Transport), Transport).
-endif.
-type wrappable() ::
binary() |
list() |
{membuffer, binary() | list()} |
{tcp, port()} |
{tcp, port(), list()} |
{file, file:io_device()} |
{file, file:io_device(), list()} |
{file, file:filename()} |
{file, file:filename(), list()}.
-spec new(wrappable()) -> {ok, #t_transport{}}.
new({membuffer, Membuffer}) when is_binary(Membuffer); is_list(Membuffer) ->
thrift_membuffer_transport:new(Membuffer);
new({membuffer, Membuffer, []}) when is_binary(Membuffer); is_list(Membuffer) ->
thrift_membuffer_transport:new(Membuffer);
new({tcp, Socket}) when is_port(Socket) ->
new({tcp, Socket, []});
new({tcp, Socket, Opts}) when is_port(Socket) ->
thrift_socket_transport:new(Socket, Opts);
new({file, Filename}) when is_list(Filename); is_binary(Filename) ->
new({file, Filename, []});
new({file, Filename, Opts}) when is_list(Filename); is_binary(Filename) ->
{ok, File} = file:open(Filename, [raw, binary]),
new({file, File, Opts});
new({file, File, Opts}) ->
thrift_file_transport:new(File, Opts).
-spec new(Module::module(), State::any()) -> {ok, #t_transport{}}.
new(Module, State) when is_atom(Module) ->
{ok, ?debug_wrap(#t_transport{module = Module, state = State})}.
-include("thrift_transport_behaviour.hrl").
read(Transport = #t_transport{module = Module}, Len)
when is_integer(Len), Len >= 0 ->
{NewState, Result} = Module:read(Transport#t_transport.state, Len),
{Transport#t_transport{state = NewState}, Result}.
read_exact(Transport = #t_transport{module = Module}, Len)
when is_integer(Len), Len >= 0 ->
case lists:keyfind(read_exact, 1, Module:module_info(exports)) of
{read_exact, 2} ->
io:fwrite("HAS EXACT"),
{NewState, Result} = Module:read_exact(Transport#t_transport.state, Len),
{Transport#t_transport{state = NewState}, Result};
_ ->
io:fwrite("~p NO EXACT", [Module]),
read(Transport, Len)
end.
write(Transport = #t_transport{module = Module}, Data) ->
{NewState, Result} = Module:write(Transport#t_transport.state, Data),
{Transport#t_transport{state = NewState}, Result}.
flush(Transport = #t_transport{module = Module}) ->
{NewState, Result} = Module:flush(Transport#t_transport.state),
{Transport#t_transport{state = NewState}, Result}.
close(Transport = #t_transport{module = Module}) ->
{NewState, Result} = Module:close(Transport#t_transport.state),
{Transport#t_transport{state = NewState}, Result}.

View file

@ -0,0 +1,117 @@
%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%
-module(thrift_transport_state_test).
-behaviour(gen_server).
-behaviour(thrift_transport).
%% API
-export([new/1]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
%% thrift_transport callbacks
-export([write/2, read/2, flush/1, close/1]).
-record(trans, {wrapped, % #thrift_transport{}
version :: integer(),
counter :: pid()
}).
-type state() :: #trans{}.
-include("thrift_transport_behaviour.hrl").
-record(state, {cversion :: integer()}).
new(WrappedTransport) ->
case gen_server:start_link(?MODULE, [], []) of
{ok, Pid} ->
Trans = #trans{wrapped = WrappedTransport,
version = 0,
counter = Pid},
thrift_transport:new(?MODULE, Trans);
Else ->
Else
end.
%%====================================================================
%% thrift_transport callbacks
%%====================================================================
write(Transport0 = #trans{wrapped = Wrapped0}, Data) ->
Transport1 = check_version(Transport0),
{Wrapped1, Result} = thrift_transport:write(Wrapped0, Data),
Transport2 = Transport1#trans{wrapped = Wrapped1},
{Transport2, Result}.
flush(Transport0 = #trans{wrapped = Wrapped0}) ->
Transport1 = check_version(Transport0),
{Wrapped1, Result} = thrift_transport:flush(Wrapped0),
Transport2 = Transport1#trans{wrapped = Wrapped1},
{Transport2, Result}.
close(Transport0 = #trans{wrapped = Wrapped0}) ->
Transport1 = check_version(Transport0),
shutdown_counter(Transport1),
{Wrapped1, Result} = thrift_transport:close(Wrapped0),
Transport2 = Transport1#trans{wrapped = Wrapped1},
{Transport2, Result}.
read(Transport0 = #trans{wrapped = Wrapped0}, Len) ->
Transport1 = check_version(Transport0),
{Wrapped1, Result} = thrift_transport:read(Wrapped0, Len),
Transport2 = Transport1#trans{wrapped = Wrapped1},
{Transport2, Result}.
%%====================================================================
%% gen_server callbacks
%%====================================================================
init([]) ->
{ok, #state{cversion = 0}}.
handle_call(check_version, _From, State = #state{cversion = Version}) ->
{reply, Version, State#state{cversion = Version+1}}.
handle_cast(shutdown, State) ->
{stop, normal, State}.
handle_info(_Info, State) -> {noreply, State}.
code_change(_OldVsn, State, _Extra) -> {ok, State}.
terminate(_Reason, _State) -> ok.
%%--------------------------------------------------------------------
%% Internal functions
%%--------------------------------------------------------------------
check_version(Transport = #trans{version = Version, counter = Counter}) ->
case gen_server:call(Counter, check_version) of
Version ->
Transport#trans{version = Version+1};
_Else ->
% State wasn't propagated properly. Die.
erlang:error(state_not_propagated)
end.
shutdown_counter(#trans{counter = Counter}) ->
gen_server:cast(Counter, shutdown).