Upgrading dependency to Thrift 0.12.0

This commit is contained in:
Renan DelValle 2018-11-27 18:03:50 -08:00
parent 3e4590dcc0
commit 356978cb42
No known key found for this signature in database
GPG key ID: C240AD6D6F443EC9
1302 changed files with 101701 additions and 26784 deletions

View file

@ -30,11 +30,22 @@ const (
PROTOCOL_ERROR = 7
)
var defaultApplicationExceptionMessage = map[int32]string{
UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception",
UNKNOWN_METHOD: "unknown method",
INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type",
WRONG_METHOD_NAME: "wrong method name",
BAD_SEQUENCE_ID: "bad sequence ID",
MISSING_RESULT: "missing result",
INTERNAL_ERROR: "unknown internal error",
PROTOCOL_ERROR: "unknown protocol error",
}
// Application level Thrift exception
type TApplicationException interface {
TException
TypeId() int32
Read(iprot TProtocol) (TApplicationException, error)
Read(iprot TProtocol) error
Write(oprot TProtocol) error
}
@ -44,7 +55,10 @@ type tApplicationException struct {
}
func (e tApplicationException) Error() string {
return e.message
if e.message != "" {
return e.message
}
return defaultApplicationExceptionMessage[e.type_]
}
func NewTApplicationException(type_ int32, message string) TApplicationException {
@ -55,10 +69,11 @@ func (p *tApplicationException) TypeId() int32 {
return p.type_
}
func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, error) {
func (p *tApplicationException) Read(iprot TProtocol) error {
// TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
if err != nil {
return nil, err
return err
}
message := ""
@ -67,7 +82,7 @@ func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, er
for {
_, ttype, id, err := iprot.ReadFieldBegin()
if err != nil {
return nil, err
return err
}
if ttype == STOP {
break
@ -76,33 +91,40 @@ func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, er
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
return nil, err
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return nil, err
return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
return nil, err
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return nil, err
return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return nil, err
return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
return nil, err
return err
}
}
return NewTApplicationException(type_, message), iprot.ReadStructEnd()
if err := iprot.ReadStructEnd(); err != nil {
return err
}
p.message = message
p.type_ = type_
return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {

View file

@ -25,17 +25,17 @@ import (
func TestTApplicationException(t *testing.T) {
exc := NewTApplicationException(UNKNOWN_APPLICATION_EXCEPTION, "")
if exc.Error() != "" {
if exc.Error() != defaultApplicationExceptionMessage[UNKNOWN_APPLICATION_EXCEPTION] {
t.Fatalf("Expected empty string for exception but found '%s'", exc.Error())
}
if exc.TypeId() != UNKNOWN_APPLICATION_EXCEPTION {
t.Fatalf("Expected type UNKNOWN for exception but found '%s'", exc.TypeId())
t.Fatalf("Expected type UNKNOWN for exception but found '%v'", exc.TypeId())
}
exc = NewTApplicationException(WRONG_METHOD_NAME, "junk_method")
if exc.Error() != "junk_method" {
t.Fatalf("Expected 'junk_method' for exception but found '%s'", exc.Error())
}
if exc.TypeId() != WRONG_METHOD_NAME {
t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%s'", exc.TypeId())
t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%v'", exc.TypeId())
}
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
@ -447,9 +448,6 @@ func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
if size < 0 {
return nil, invalidDataLength
}
if uint64(size) > p.trans.RemainingBytes() {
return nil, invalidDataLength
}
isize := int(size)
buf := make([]byte, isize)
@ -457,8 +455,8 @@ func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
return buf, NewTProtocolException(err)
}
func (p *TBinaryProtocol) Flush() (err error) {
return NewTProtocolException(p.trans.Flush())
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
@ -480,9 +478,6 @@ func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
if uint64(size) > p.trans.RemainingBytes() {
return "", invalidDataLength
}
var (
buf bytes.Buffer

View file

@ -21,6 +21,7 @@ package thrift
import (
"bufio"
"context"
)
type TBufferedTransportFactory struct {
@ -78,12 +79,12 @@ func (p *TBufferedTransport) Write(b []byte) (int, error) {
return n, err
}
func (p *TBufferedTransport) Flush() error {
func (p *TBufferedTransport) Flush(ctx context.Context) error {
if err := p.ReadWriter.Flush(); err != nil {
p.ReadWriter.Writer.Reset(p.tp)
return err
}
return p.tp.Flush()
return p.tp.Flush(ctx)
}
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {

View file

@ -0,0 +1,85 @@
package thrift
import (
"context"
"fmt"
)
type TClient interface {
Call(ctx context.Context, method string, args, result TStruct) error
}
type TStandardClient struct {
seqId int32
iprot, oprot TProtocol
}
// TStandardClient implements TClient, and uses the standard message format for Thrift.
// It is not safe for concurrent use.
func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
return &TStandardClient{
iprot: inputProtocol,
oprot: outputProtocol,
}
}
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
return err
}
if err := args.Write(oprot); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
return err
}
return oprot.Flush(ctx)
}
func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
if err != nil {
return err
}
if method != rMethod {
return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
} else if seqId != rSeqId {
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
if err := exception.Read(iprot); err != nil {
return err
}
if err := iprot.ReadMessageEnd(); err != nil {
return err
}
return &exception
} else if rTypeId != REPLY {
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
if err := result.Read(iprot); err != nil {
return err
}
return iprot.ReadMessageEnd()
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
p.seqId++
seqId := p.seqId
if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
return err
}
// method is oneway
if result == nil {
return nil
}
return p.Recv(p.iprot, seqId, method, result)
}

View file

@ -19,12 +19,12 @@
package thrift
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(in, out TProtocol) (bool, TException)
import "context"
type mockProcessor struct {
ProcessFunc func(in, out TProtocol) (bool, TException)
}
type TProcessorFunction interface {
Process(seqId int32, in, out TProtocol) (bool, TException)
func (m *mockProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
return m.ProcessFunc(in, out)
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"encoding/binary"
"fmt"
"io"
@ -561,9 +562,6 @@ func (p *TCompactProtocol) ReadString() (value string, err error) {
if length < 0 {
return "", invalidDataLength
}
if uint64(length) > p.trans.RemainingBytes() {
return "", invalidDataLength
}
if length == 0 {
return "", nil
@ -590,17 +588,14 @@ func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
if length < 0 {
return nil, invalidDataLength
}
if uint64(length) > p.trans.RemainingBytes() {
return nil, invalidDataLength
}
buf := make([]byte, length)
_, e = io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(e)
}
func (p *TCompactProtocol) Flush() (err error) {
return NewTProtocolException(p.trans.Flush())
func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
@ -806,7 +801,7 @@ func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
case COMPACT_STRUCT:
return STRUCT, nil
}
return STOP, TException(fmt.Errorf("don't know what type: %s", t&0x0f))
return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.

View file

@ -26,11 +26,18 @@ import (
func TestReadWriteCompactProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTCompactProtocolFactory())
transports := []TTransport{
NewTMemoryBuffer(),
NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))),
NewTFramedTransport(NewTMemoryBuffer()),
}
zlib0, _ := NewTZlibTransport(NewTMemoryBuffer(), 0)
zlib6, _ := NewTZlibTransport(NewTMemoryBuffer(), 6)
zlib9, _ := NewTZlibTransport(NewTFramedTransport(NewTMemoryBuffer()), 9)
transports = append(transports, zlib0, zlib6, zlib9)
for _, trans := range transports {
p := NewTCompactProtocol(trans)
ReadWriteBool(t, p, trans)

View file

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
var defaultCtx = context.Background()

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"log"
)
@ -258,8 +259,8 @@ func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
func (tdp *TDebugProtocol) Flush() (err error) {
err = tdp.Delegate.Flush()
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
return
}

View file

@ -22,6 +22,7 @@ package thrift
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
@ -135,21 +136,23 @@ func (p *TFramedTransport) WriteString(s string) (n int, err error) {
return p.buf.WriteString(s)
}
func (p *TFramedTransport) Flush() error {
func (p *TFramedTransport) Flush(ctx context.Context) error {
size := p.buf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
if size > 0 {
if n, err := p.buf.WriteTo(p.transport); err != nil {
print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n")
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
}
err = p.transport.Flush()
err = p.transport.Flush(ctx)
return NewTTransportExceptionFromError(err)
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
@ -181,7 +182,7 @@ func (p *THttpClient) WriteString(s string) (n int, err error) {
return p.requestBuffer.WriteString(s)
}
func (p *THttpClient) Flush() error {
func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
@ -190,6 +191,9 @@ func (p *THttpClient) Flush() error {
return NewTTransportExceptionFromError(err)
}
req.Header = p.header
if ctx != nil {
req = req.WithContext(ctx)
}
response, err := p.client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)

View file

@ -19,16 +19,45 @@
package thrift
import "net/http"
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
func NewThriftHandlerFunc(processor TProcessor,
inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
return gz(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/x-thrift")
transport := NewStreamTransport(r.Body, w)
processor.Process(inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
processor.Process(r.Context(), inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
})
}
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bufio"
"context"
"io"
)
@ -138,7 +139,7 @@ func (p *StreamTransport) Close() error {
}
// Flushes the underlying output stream if not null.
func (p *StreamTransport) Flush() error {
func (p *StreamTransport) Flush(ctx context.Context) error {
if p.Writer == nil {
return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream")
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"encoding/base64"
"fmt"
)
@ -438,10 +439,10 @@ func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) Flush() (err error) {
func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
err = p.writer.Flush()
if err == nil {
err = p.trans.Flush()
err = p.trans.Flush(ctx)
}
return NewTProtocolException(err)
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
@ -36,7 +37,7 @@ func TestWriteJSONProtocolBool(t *testing.T) {
if e := p.WriteBool(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -68,7 +69,7 @@ func TestReadJSONProtocolBool(t *testing.T) {
} else {
trans.Write([]byte{'0'}) // not JSON_FALSE
}
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBool()
if e != nil {
@ -94,7 +95,7 @@ func TestWriteJSONProtocolByte(t *testing.T) {
if e := p.WriteByte(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -116,7 +117,7 @@ func TestReadJSONProtocolByte(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadByte()
if e != nil {
@ -141,7 +142,7 @@ func TestWriteJSONProtocolI16(t *testing.T) {
if e := p.WriteI16(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -163,7 +164,7 @@ func TestReadJSONProtocolI16(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI16()
if e != nil {
@ -188,7 +189,7 @@ func TestWriteJSONProtocolI32(t *testing.T) {
if e := p.WriteI32(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -210,7 +211,7 @@ func TestReadJSONProtocolI32(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI32()
if e != nil {
@ -235,7 +236,7 @@ func TestWriteJSONProtocolI64(t *testing.T) {
if e := p.WriteI64(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -257,7 +258,7 @@ func TestReadJSONProtocolI64(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.FormatInt(value, 10))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI64()
if e != nil {
@ -282,7 +283,7 @@ func TestWriteJSONProtocolDouble(t *testing.T) {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -319,7 +320,7 @@ func TestReadJSONProtocolDouble(t *testing.T) {
p := NewTJSONProtocol(trans)
n := NewNumericFromDouble(value)
trans.WriteString(n.String())
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadDouble()
if e != nil {
@ -358,7 +359,7 @@ func TestWriteJSONProtocolString(t *testing.T) {
if e := p.WriteString(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -380,7 +381,7 @@ func TestReadJSONProtocolString(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(jsonQuote(value))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadString()
if e != nil {
@ -409,7 +410,7 @@ func TestWriteJSONProtocolBinary(t *testing.T) {
if e := p.WriteBinary(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -441,7 +442,7 @@ func TestReadJSONProtocolBinary(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(jsonQuote(b64String))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBinary()
if e != nil {
@ -474,7 +475,7 @@ func TestWriteJSONProtocolList(t *testing.T) {
}
}
p.WriteListEnd()
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
@ -528,7 +529,7 @@ func TestWriteJSONProtocolSet(t *testing.T) {
}
}
p.WriteSetEnd()
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
@ -585,12 +586,12 @@ func TestWriteJSONProtocolMap(t *testing.T) {
}
}
p.WriteMapEnd()
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
if str[0] != '[' || str[len(str)-1] != ']' {
t.Fatalf("Bad value for %s, wrote: %q, in go: %q", thetype, str, DOUBLE_VALUES)
t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES)
}
expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin()
if err != nil {

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
)
// Memory buffer-based implementation of the TTransport interface.
@ -70,7 +71,7 @@ func (p *TMemoryBuffer) Close() error {
}
// Flushing a memory buffer is a no-op
func (p *TMemoryBuffer) Flush() error {
func (p *TMemoryBuffer) Flush(ctx context.Context) error {
return nil
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"fmt"
"strings"
)
@ -127,7 +128,7 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProces
t.serviceProcessorMap[name] = processor
}
func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) {
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin()
if err != nil {
return false, err
@ -140,7 +141,7 @@ func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) {
if len(v) != 2 {
if t.DefaultProcessor != nil {
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(smb, out)
return t.DefaultProcessor.Process(ctx, smb, out)
}
return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
}
@ -149,7 +150,7 @@ func (t *TMultiplexedProcessor) Process(in, out TProtocol) (bool, TException) {
return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(smb, out)
return actualProcessor.Process(ctx, smb, out)
}
//Protocol that use stored message for ReadMessageBegin

View file

@ -19,6 +19,18 @@
package thrift
import "context"
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(ctx context.Context, in, out TProtocol) (bool, TException)
}
type TProcessorFunction interface {
Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}
// The default processor factory just returns a singleton
// instance.
type TProcessorFactory interface {

View file

@ -20,7 +20,9 @@
package thrift
import (
"context"
"errors"
"fmt"
)
const (
@ -73,7 +75,7 @@ type TProtocol interface {
ReadBinary() (value []byte, err error)
Skip(fieldType TType) (err error)
Flush() (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
}
@ -170,6 +172,8 @@ func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
}
}
return self.ReadListEnd()
default:
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType)))
}
return nil
}

View file

@ -21,6 +21,7 @@ package thrift
import (
"bytes"
"context"
"io/ioutil"
"math"
"net"
@ -31,7 +32,6 @@ import (
const PROTOCOL_BINARY_DATA_SIZE = 155
var (
data string // test data for writing
protocol_bdata []byte // test data for writing; same as data
BOOL_VALUES []bool
BYTE_VALUES []int8
@ -47,7 +47,6 @@ func init() {
for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ {
protocol_bdata[i] = byte((i + 'a') % 255)
}
data = string(protocol_bdata)
BOOL_VALUES = []bool{false, true, false, false, true}
BYTE_VALUES = []int8{117, 0, 1, 32, 127, -128, -1}
INT16_VALUES = []int16{459, 0, 1, -1, -128, 127, 32767, -32768}
@ -120,6 +119,9 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(0, NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(6, NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(9, NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024))),
NewTHttpPostClientTransportFactory("http://" + addr.String()),
}
for _, tf := range transports {
@ -227,17 +229,17 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
for k, v := range BOOL_VALUES {
err = p.WriteBool(v)
if err != nil {
t.Errorf("%s: %T %T %q Error writing bool in list at index %d: %q", "ReadWriteBool", p, trans, err, k, v)
t.Errorf("%s: %T %T %v Error writing bool in list at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
}
p.WriteListEnd()
if err != nil {
t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteBool", p, trans, err, BOOL_VALUES)
t.Errorf("%s: %T %T %v Error writing list end: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
p.Flush()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteBool", p, trans, err, BOOL_VALUES)
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
@ -245,16 +247,16 @@ func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteBool", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteBool", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteBool", p, trans, thelen, thelen2)
}
}
for k, v := range BOOL_VALUES {
value, err := p.ReadBool()
if err != nil {
t.Errorf("%s: %T %T %q Error reading bool at index %d: %q", "ReadWriteBool", p, trans, err, k, v)
t.Errorf("%s: %T %T %v Error reading bool at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: index %d %q %q %q != %q", "ReadWriteBool", k, p, trans, v, value)
t.Errorf("%s: index %v %v %v %v != %v", "ReadWriteBool", k, p, trans, v, value)
}
}
err = p.ReadListEnd()
@ -280,7 +282,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
if err != nil {
t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
err = p.Flush()
err = p.Flush(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
@ -294,7 +296,7 @@ func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteByte", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteByte", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteByte", p, trans, thelen, thelen2)
}
}
for k, v := range BYTE_VALUES {
@ -320,7 +322,7 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
p.WriteI16(v)
}
p.WriteListEnd()
p.Flush()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES)
@ -331,7 +333,7 @@ func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI16", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI16", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI16", p, trans, thelen, thelen2)
}
}
for k, v := range INT16_VALUES {
@ -357,7 +359,7 @@ func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
p.WriteI32(v)
}
p.WriteListEnd()
p.Flush()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES)
@ -368,7 +370,7 @@ func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI32", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI32", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI32", p, trans, thelen, thelen2)
}
}
for k, v := range INT32_VALUES {
@ -393,7 +395,7 @@ func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
p.WriteI64(v)
}
p.WriteListEnd()
p.Flush()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES)
@ -404,7 +406,7 @@ func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI64", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteI64", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI64", p, trans, thelen, thelen2)
}
}
for k, v := range INT64_VALUES {
@ -429,28 +431,28 @@ func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
p.WriteDouble(v)
}
p.WriteListEnd()
p.Flush()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES)
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES)
}
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteDouble", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteDouble", p, trans, thelen, thelen2)
}
for k, v := range DOUBLE_VALUES {
value, err := p.ReadDouble()
if err != nil {
t.Errorf("%s: %T %T %q Error reading double at index %d: %q", "ReadWriteDouble", p, trans, err, k, v)
t.Errorf("%s: %T %T %q Error reading double at index %d: %v", "ReadWriteDouble", p, trans, err, k, v)
}
if math.IsNaN(v) {
if !math.IsNaN(value) {
t.Errorf("%s: %T %T math.IsNaN(%q) != math.IsNaN(%q)", "ReadWriteDouble", p, trans, v, value)
t.Errorf("%s: %T %T math.IsNaN(%v) != math.IsNaN(%v)", "ReadWriteDouble", p, trans, v, value)
}
} else if v != value {
t.Errorf("%s: %T %T %v != %q", "ReadWriteDouble", p, trans, v, value)
t.Errorf("%s: %T %T %v != %v", "ReadWriteDouble", p, trans, v, value)
}
}
err = p.ReadListEnd()
@ -467,7 +469,7 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
p.WriteString(v)
}
p.WriteListEnd()
p.Flush()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES)
@ -478,7 +480,7 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteString", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %s != len %s", "ReadWriteString", p, trans, thelen, thelen2)
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteString", p, trans, thelen, thelen2)
}
}
for k, v := range STRING_VALUES {
@ -487,7 +489,7 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteString", p, trans, v, value)
t.Errorf("%s: %T %T %v != %v", "ReadWriteString", p, trans, v, value)
}
}
if err != nil {
@ -498,7 +500,7 @@ func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) {
v := protocol_bdata
p.WriteBinary(v)
p.Flush()
p.Flush(context.Background())
value, err := p.ReadBinary()
if err != nil {
t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error())

View file

@ -19,6 +19,10 @@
package thrift
import (
"context"
)
type TSerializer struct {
Transport *TMemoryBuffer
Protocol TProtocol
@ -38,35 +42,35 @@ func NewTSerializer() *TSerializer {
protocol}
}
func (t *TSerializer) WriteString(msg TStruct) (s string, err error) {
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(); err != nil {
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(); err != nil {
if err = t.Transport.Flush(ctx); err != nil {
return
}
return t.Transport.String(), nil
}
func (t *TSerializer) Write(msg TStruct) (b []byte, err error) {
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(); err != nil {
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(); err != nil {
if err = t.Transport.Flush(ctx); err != nil {
return
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"errors"
"fmt"
"testing"
@ -88,7 +89,7 @@ func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) {
m.StringSet = make(map[string]struct{}, 5)
m.E = 2
s, err := t.WriteString(&m)
s, err := t.WriteString(context.Background(), &m)
if err != nil {
return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
}
@ -122,7 +123,7 @@ func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) {
m.StringSet = make(map[string]struct{}, 5)
m.E = 2
s, err := t.WriteString(&m)
s, err := t.WriteString(context.Background(), &m)
if err != nil {
return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))

View file

@ -47,7 +47,14 @@ func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*T
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil
}
// Creates a TServerSocket from a net.Addr
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
}
func (p *TServerSocket) Listen() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return nil
}
@ -67,10 +74,13 @@ func (p *TServerSocket) Accept() (TTransport, error) {
if interrupted {
return nil, errTransportInterrupted
}
if p.listener == nil {
listener := p.listener
if listener == nil {
return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
}
conn, err := p.listener.Accept()
conn, err := listener.Accept()
if err != nil {
return nil, NewTTransportExceptionFromError(err)
}
@ -84,6 +94,8 @@ func (p *TServerSocket) IsListening() bool {
// Connects the socket, creating a new socket object if necessary.
func (p *TServerSocket) Open() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
@ -114,9 +126,9 @@ func (p *TServerSocket) Close() error {
func (p *TServerSocket) Interrupt() error {
p.mu.Lock()
defer p.mu.Unlock()
p.interrupted = true
p.Close()
p.mu.Unlock()
return nil
}

View file

@ -41,6 +41,16 @@ func TestSocketIsntListeningAfterInterrupt(t *testing.T) {
}
}
func TestSocketConcurrency(t *testing.T) {
host := "127.0.0.1"
port := 9090
addr := fmt.Sprintf("%s:%d", host, port)
socket := CreateServerSocket(t, addr)
go func() { socket.Listen() }()
go func() { socket.Interrupt() }()
}
func CreateServerSocket(t *testing.T, addr string) *TServerSocket {
socket, err := NewTServerSocket(addr)
if err != nil {

View file

@ -22,6 +22,7 @@ package thrift
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
@ -552,7 +553,7 @@ func (p *TSimpleJSONProtocol) ReadBinary() ([]byte, error) {
return v, p.ParsePostValue()
}
func (p *TSimpleJSONProtocol) Flush() (err error) {
func (p *TSimpleJSONProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.writer.Flush())
}
@ -1064,7 +1065,7 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
for _, char := range line {
switch char {
default:
e := fmt.Errorf("Expecting end of list \"]\", but found: \"", line, "\"")
e := fmt.Errorf("Expecting end of list \"]\", but found: \"%v\"", line)
return NewTProtocolExceptionWithType(INVALID_DATA, e)
case ' ', '\n', '\r', '\t', rune(JSON_RBRACKET[0]):
break

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
@ -37,7 +38,7 @@ func TestWriteSimpleJSONProtocolBool(t *testing.T) {
if e := p.WriteBool(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -63,7 +64,7 @@ func TestReadSimpleJSONProtocolBool(t *testing.T) {
} else {
trans.Write(JSON_FALSE)
}
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBool()
if e != nil {
@ -88,7 +89,7 @@ func TestWriteSimpleJSONProtocolByte(t *testing.T) {
if e := p.WriteByte(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -110,7 +111,7 @@ func TestReadSimpleJSONProtocolByte(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadByte()
if e != nil {
@ -135,7 +136,7 @@ func TestWriteSimpleJSONProtocolI16(t *testing.T) {
if e := p.WriteI16(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -157,7 +158,7 @@ func TestReadSimpleJSONProtocolI16(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI16()
if e != nil {
@ -182,7 +183,7 @@ func TestWriteSimpleJSONProtocolI32(t *testing.T) {
if e := p.WriteI32(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -204,7 +205,7 @@ func TestReadSimpleJSONProtocolI32(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI32()
if e != nil {
@ -228,7 +229,7 @@ func TestReadSimpleJSONProtocolI32Null(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(value)
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI32()
@ -250,7 +251,7 @@ func TestWriteSimpleJSONProtocolI64(t *testing.T) {
if e := p.WriteI64(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -272,7 +273,7 @@ func TestReadSimpleJSONProtocolI64(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(strconv.FormatInt(value, 10))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI64()
if e != nil {
@ -296,7 +297,7 @@ func TestReadSimpleJSONProtocolI64Null(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(value)
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI64()
@ -318,7 +319,7 @@ func TestWriteSimpleJSONProtocolDouble(t *testing.T) {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -355,7 +356,7 @@ func TestReadSimpleJSONProtocolDouble(t *testing.T) {
p := NewTSimpleJSONProtocol(trans)
n := NewNumericFromDouble(value)
trans.WriteString(n.String())
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadDouble()
if e != nil {
@ -394,7 +395,7 @@ func TestWriteSimpleJSONProtocolString(t *testing.T) {
if e := p.WriteString(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -416,7 +417,7 @@ func TestReadSimpleJSONProtocolString(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(jsonQuote(value))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadString()
if e != nil {
@ -440,7 +441,7 @@ func TestReadSimpleJSONProtocolStringNull(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(value)
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadString()
if e != nil {
@ -464,7 +465,7 @@ func TestWriteSimpleJSONProtocolBinary(t *testing.T) {
if e := p.WriteBinary(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
@ -487,7 +488,7 @@ func TestReadSimpleJSONProtocolBinary(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(jsonQuote(b64String))
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBinary()
if e != nil {
@ -516,7 +517,7 @@ func TestReadSimpleJSONProtocolBinaryNull(t *testing.T) {
trans := NewTMemoryBuffer()
p := NewTSimpleJSONProtocol(trans)
trans.WriteString(value)
trans.Flush()
trans.Flush(context.Background())
s := trans.String()
b, e := p.ReadBinary()
v := string(b)
@ -542,7 +543,7 @@ func TestWriteSimpleJSONProtocolList(t *testing.T) {
}
}
p.WriteListEnd()
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
@ -596,7 +597,7 @@ func TestWriteSimpleJSONProtocolSet(t *testing.T) {
}
}
p.WriteSetEnd()
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
@ -653,12 +654,12 @@ func TestWriteSimpleJSONProtocolMap(t *testing.T) {
}
}
p.WriteMapEnd()
if e := p.Flush(); e != nil {
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
if str[0] != '[' || str[len(str)-1] != ']' {
t.Fatalf("Bad value for %s, wrote: %q, in go: %q", thetype, str, DOUBLE_VALUES)
t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES)
}
l := strings.Split(str[1:len(str)-1], ",")
if len(l) < 3 {

View file

@ -23,11 +23,18 @@ import (
"log"
"runtime/debug"
"sync"
"sync/atomic"
)
// Simple, non-concurrent server for testing.
/*
* This is not a typical TSimpleServer as it is not blocked after accept a socket.
* It is more like a TThreadedServer that can handle different connections in different goroutines.
* This will work if golang user implements a conn-pool like thing in client side.
*/
type TSimpleServer struct {
quit chan struct{}
closed int32
wg sync.WaitGroup
mu sync.Mutex
processorFactory TProcessorFactory
serverTransport TServerTransport
@ -87,7 +94,6 @@ func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTranspor
outputTransportFactory: outputTransportFactory,
inputProtocolFactory: inputProtocolFactory,
outputProtocolFactory: outputProtocolFactory,
quit: make(chan struct{}, 1),
}
}
@ -119,23 +125,37 @@ func (p *TSimpleServer) Listen() error {
return p.serverTransport.Listen()
}
func (p *TSimpleServer) innerAccept() (int32, error) {
client, err := p.serverTransport.Accept()
p.mu.Lock()
defer p.mu.Unlock()
closed := atomic.LoadInt32(&p.closed)
if closed != 0 {
return closed, nil
}
if err != nil {
return 0, err
}
if client != nil {
p.wg.Add(1)
go func() {
defer p.wg.Done()
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
}
}()
}
return 0, nil
}
func (p *TSimpleServer) AcceptLoop() error {
for {
client, err := p.serverTransport.Accept()
closed, err := p.innerAccept()
if err != nil {
select {
case <-p.quit:
return nil
default:
}
return err
}
if client != nil {
go func() {
if err := p.processRequests(client); err != nil {
log.Println("error processing request:", err)
}
}()
if closed != 0 {
return nil
}
}
}
@ -149,14 +169,15 @@ func (p *TSimpleServer) Serve() error {
return nil
}
var once sync.Once
func (p *TSimpleServer) Stop() error {
q := func() {
p.quit <- struct{}{}
p.serverTransport.Interrupt()
p.mu.Lock()
defer p.mu.Unlock()
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
once.Do(q)
atomic.StoreInt32(&p.closed, 1)
p.serverTransport.Interrupt()
p.wg.Wait()
return nil
}
@ -177,6 +198,7 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
log.Printf("panic in processor: %s: %s", e, debug.Stack())
}
}()
if inputTransport != nil {
defer inputTransport.Close()
}
@ -184,17 +206,20 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
defer outputTransport.Close()
}
for {
ok, err := processor.Process(inputProtocol, outputProtocol)
if atomic.LoadInt32(&p.closed) != 0 {
return nil
}
ok, err := processor.Process(defaultCtx, inputProtocol, outputProtocol)
if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
return nil
} else if err != nil {
log.Printf("error processing request: %s", err)
return err
}
if err, ok := err.(TApplicationException); ok && err.TypeId() == UNKNOWN_METHOD {
continue
}
if !ok {
if !ok {
break
}
}

View file

@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
"errors"
"runtime"
)
type mockServerTransport struct {
ListenFunc func() error
AcceptFunc func() (TTransport, error)
CloseFunc func() error
InterruptFunc func() error
}
func (m *mockServerTransport) Listen() error {
return m.ListenFunc()
}
func (m *mockServerTransport) Accept() (TTransport, error) {
return m.AcceptFunc()
}
func (m *mockServerTransport) Close() error {
return m.CloseFunc()
}
func (m *mockServerTransport) Interrupt() error {
return m.InterruptFunc()
}
type mockTTransport struct {
TTransport
}
func (m *mockTTransport) Close() error {
return nil
}
func TestMultipleStop(t *testing.T) {
proc := &mockProcessor{
ProcessFunc: func(in, out TProtocol) (bool, TException) {
return false, nil
},
}
var interruptCalled bool
c := make(chan struct{})
trans := &mockServerTransport{
ListenFunc: func() error {
return nil
},
AcceptFunc: func() (TTransport, error) {
<-c
return nil, nil
},
CloseFunc: func() error {
c <- struct{}{}
return nil
},
InterruptFunc: func() error {
interruptCalled = true
return nil
},
}
serv := NewTSimpleServer2(proc, trans)
go serv.Serve()
serv.Stop()
if !interruptCalled {
t.Error("first server transport should have been interrupted")
}
serv = NewTSimpleServer2(proc, trans)
interruptCalled = false
go serv.Serve()
serv.Stop()
if !interruptCalled {
t.Error("second server transport should have been interrupted")
}
}
func TestWaitRace(t *testing.T) {
proc := &mockProcessor{
ProcessFunc: func(in, out TProtocol) (bool, TException) {
return false, nil
},
}
trans := &mockServerTransport{
ListenFunc: func() error {
return nil
},
AcceptFunc: func() (TTransport, error) {
return &mockTTransport{}, nil
},
CloseFunc: func() error {
return nil
},
InterruptFunc: func() error {
return nil
},
}
serv := NewTSimpleServer2(proc, trans)
go serv.Serve()
runtime.Gosched()
serv.Stop()
}
func TestNoHangDuringStopFromDanglingLockAcquireDuringAcceptLoop(t *testing.T) {
proc := &mockProcessor{
ProcessFunc: func(in, out TProtocol) (bool, TException) {
return false, nil
},
}
trans := &mockServerTransport{
ListenFunc: func() error {
return nil
},
AcceptFunc: func() (TTransport, error) {
return nil, errors.New("no sir")
},
CloseFunc: func() error {
return nil
},
InterruptFunc: func() error {
return nil
},
}
serv := NewTSimpleServer2(proc, trans)
go serv.Serve()
runtime.Gosched()
serv.Stop()
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"net"
"time"
)
@ -148,7 +149,7 @@ func (p *TSocket) Write(buf []byte) (int, error) {
return p.conn.Write(buf)
}
func (p *TSocket) Flush() error {
func (p *TSocket) Flush(ctx context.Context) error {
return nil
}

View file

@ -20,9 +20,9 @@
package thrift
import (
"crypto/tls"
"net"
"time"
"crypto/tls"
)
type TSSLServerSocket struct {
@ -38,6 +38,9 @@ func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket,
}
func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) {
if cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
addr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"crypto/tls"
"net"
"time"
@ -48,6 +49,9 @@ func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
// it also accepts a tls Configuration and a timeout as a time.Duration
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
if cfg.MinVersion == 0 {
cfg.MinVersion = tls.VersionTLS10
}
return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil
}
@ -87,7 +91,8 @@ func (p *TSSLSocket) Open() error {
// If we have a hostname, we need to pass the hostname to tls.Dial for
// certificate hostname checks.
if p.hostPort != "" {
if p.conn, err = tls.Dial("tcp", p.hostPort, p.cfg); err != nil {
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
} else {
@ -103,7 +108,8 @@ func (p *TSSLSocket) Open() error {
if len(p.addr.String()) == 0 {
return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
}
if p.conn, err = tls.Dial(p.addr.Network(), p.addr.String(), p.cfg); err != nil {
if p.conn, err = tls.DialWithDialer(&net.Dialer{
Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil {
return NewTTransportException(NOT_OPEN, err.Error())
}
}
@ -153,7 +159,7 @@ func (p *TSSLSocket) Write(buf []byte) (int, error) {
return p.conn.Write(buf)
}
func (p *TSSLSocket) Flush() error {
func (p *TSSLSocket) Flush(ctx context.Context) error {
return nil
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"errors"
"io"
)
@ -30,6 +31,10 @@ type Flusher interface {
Flush() (err error)
}
type ContextFlusher interface {
Flush(ctx context.Context) (err error)
}
type ReadSizeProvider interface {
RemainingBytes() (num_bytes uint64)
}
@ -37,7 +42,7 @@ type ReadSizeProvider interface {
// Encapsulates the I/O layer
type TTransport interface {
io.ReadWriteCloser
Flusher
ContextFlusher
ReadSizeProvider
// Opens the transport for communication
@ -60,6 +65,6 @@ type TRichTransport interface {
io.ByteReader
io.ByteWriter
stringWriter
Flusher
ContextFlusher
ReadSizeProvider
}

View file

@ -20,6 +20,7 @@
package thrift
import (
"context"
"io"
"net"
"strconv"
@ -54,7 +55,7 @@ func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
if err != nil {
t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err)
}
err = writeTrans.Flush()
err = writeTrans.Flush(context.Background())
if err != nil {
t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err)
}
@ -74,7 +75,7 @@ func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
if err != nil {
t.Fatalf("Transport %T cannot write binary data 2 of length %d: %s", writeTrans, len(transport_bdata), err)
}
err = writeTrans.Flush()
err = writeTrans.Flush(context.Background())
if err != nil {
t.Fatalf("Transport %T cannot flush write binary data 2: %s", writeTrans, err)
}
@ -113,7 +114,7 @@ func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTranspo
if err != nil {
t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err)
}
err = writeTrans.Flush()
err = writeTrans.Flush(context.Background())
if err != nil {
t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err)
}

View file

@ -21,13 +21,15 @@ package thrift
import (
"compress/zlib"
"context"
"io"
"log"
)
// TZlibTransportFactory is a factory for TZlibTransport instances
type TZlibTransportFactory struct {
level int
level int
factory TTransportFactory
}
// TZlibTransport is a TTransport implementation that makes use of zlib compression.
@ -39,12 +41,26 @@ type TZlibTransport struct {
// GetTransport constructs a new instance of NewTZlibTransport
func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if p.factory != nil {
// wrap other factory
var err error
trans, err = p.factory.GetTransport(trans)
if err != nil {
return nil, err
}
}
return NewTZlibTransport(trans, p.level)
}
// NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory
func NewTZlibTransportFactory(level int) *TZlibTransportFactory {
return &TZlibTransportFactory{level: level}
return &TZlibTransportFactory{level: level, factory: nil}
}
// NewTZlibTransportFactory constructs a new instance of TZlibTransportFactory
// as a wrapper over existing transport factory
func NewTZlibTransportFactoryWithFactory(level int, factory TTransportFactory) *TZlibTransportFactory {
return &TZlibTransportFactory{level: level, factory: factory}
}
// NewTZlibTransport constructs a new instance of TZlibTransport
@ -76,11 +92,11 @@ func (z *TZlibTransport) Close() error {
}
// Flush flushes the writer and its underlying transport.
func (z *TZlibTransport) Flush() error {
func (z *TZlibTransport) Flush(ctx context.Context) error {
if err := z.writer.Flush(); err != nil {
return err
}
return z.transport.Flush()
return z.transport.Flush(ctx)
}
// IsOpen returns true if the transport is open

View file

@ -31,3 +31,32 @@ func TestZlibTransport(t *testing.T) {
}
TransportTest(t, trans, trans)
}
type DummyTransportFactory struct{}
func (p *DummyTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return NewTMemoryBuffer(), nil
}
func TestZlibFactoryTransportWithFactory(t *testing.T) {
factory := NewTZlibTransportFactoryWithFactory(
zlib.BestCompression,
&DummyTransportFactory{},
)
buffer := NewTMemoryBuffer()
trans, err := factory.GetTransport(buffer)
if err != nil {
t.Fatal(err)
}
TransportTest(t, trans, trans)
}
func TestZlibFactoryTransportWithoutFactory(t *testing.T) {
factory := NewTZlibTransportFactoryWithFactory(zlib.BestCompression, nil)
buffer := NewTMemoryBuffer()
trans, err := factory.GetTransport(buffer)
if err != nil {
t.Fatal(err)
}
TransportTest(t, trans, trans)
}