Upgrading dependency to Thrift 0.12.0

This commit is contained in:
Renan DelValle 2018-11-27 18:03:50 -08:00
parent 3e4590dcc0
commit 356978cb42
No known key found for this signature in database
GPG key ID: C240AD6D6F443EC9
1302 changed files with 101701 additions and 26784 deletions

View file

@ -33,7 +33,7 @@ set(haskell_sources
src/Thrift/Transport/HttpClient.hs
src/Thrift/Transport/IOBuffer.hs
src/Thrift/Types.hs
Thrift.cabal
thrift.cabal
)
if(BUILD_TESTING)

View file

@ -24,7 +24,7 @@ EXTRA_DIST = \
README.md \
Setup.lhs \
TODO \
Thrift.cabal \
thrift.cabal \
src \
test

View file

@ -18,8 +18,8 @@
--
Name: thrift
Version: 0.10.0
Cabal-Version: >= 1.8
Version: 1.0.0-dev
Cabal-Version: >= 1.24
License: OtherLicense
Category: Foreign
Build-Type: Simple
@ -40,7 +40,7 @@ Library
Hs-Source-Dirs:
src
Build-Depends:
base >= 4, base < 5, containers, ghc-prim, attoparsec, binary, bytestring >= 0.10, base64-bytestring, hashable, HTTP, text, unordered-containers >= 0.2.6, vector == 0.10.12.2, QuickCheck >= 2.8.2, split
base >= 4, base < 5, containers, ghc-prim, attoparsec, binary, bytestring >= 0.10, base64-bytestring, hashable, HTTP, text, hspec-core > 2.4.0, unordered-containers >= 0.2.6, vector >= 0.10.12.2, QuickCheck >= 2.8.2, split
if flag(network-uri)
build-depends: network-uri >= 2.6, network >= 2.6
else
@ -49,6 +49,7 @@ Library
Thrift,
Thrift.Arbitraries
Thrift.Protocol,
Thrift.Protocol.Header,
Thrift.Protocol.Binary,
Thrift.Protocol.Compact,
Thrift.Protocol.JSON,
@ -57,11 +58,14 @@ Library
Thrift.Transport.Empty,
Thrift.Transport.Framed,
Thrift.Transport.Handle,
Thrift.Transport.Header,
Thrift.Transport.HttpClient,
Thrift.Transport.IOBuffer,
Thrift.Transport.Memory,
Thrift.Types
Extensions:
Default-Language:
Haskell2010
Default-Extensions:
DeriveDataTypeable,
ExistentialQuantification,
FlexibleInstances,

View file

@ -90,13 +90,13 @@ data AppExn = AppExn { ae_type :: AppExnType, ae_message :: String }
deriving ( Show, Typeable )
instance Exception AppExn
writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO ()
writeAppExn :: Protocol p => p -> AppExn -> IO ()
writeAppExn pt ae = writeVal pt $ TStruct $ Map.fromList
[ (1, ("message", TString $ encodeUtf8 $ pack $ ae_message ae))
, (2, ("type", TI32 $ fromIntegral $ fromEnum (ae_type ae)))
]
readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn
readAppExn :: Protocol p => p -> IO AppExn
readAppExn pt = do
let typemap = Map.fromList [(1,("message",T_STRING)),(2,("type",T_I32))]
TStruct fields <- readVal pt $ T_STRUCT typemap

View file

@ -22,12 +22,11 @@
module Thrift.Protocol
( Protocol(..)
, StatelessProtocol(..)
, ProtocolExn(..)
, ProtocolExnType(..)
, getTypeOf
, runParser
, versionMask
, version1
, bsToDouble
, bsToDoubleLE
) where
@ -35,7 +34,6 @@ module Thrift.Protocol
import Control.Exception
import Data.Attoparsec.ByteString
import Data.Bits
import Data.ByteString.Lazy (ByteString, toStrict)
import Data.ByteString.Unsafe
import Data.Functor ((<$>))
import Data.Int
@ -44,37 +42,26 @@ import Data.Text.Lazy (Text)
import Data.Typeable (Typeable)
import Data.Word
import Foreign.Ptr (castPtr)
import Foreign.Storable (Storable, peek, poke)
import Foreign.Storable (peek, poke)
import System.IO.Unsafe
import qualified Data.ByteString as BS
import qualified Data.HashMap.Strict as Map
import qualified Data.ByteString.Lazy as LBS
import Thrift.Types
import Thrift.Transport
versionMask :: Int32
versionMask = fromIntegral (0xffff0000 :: Word32)
version1 :: Int32
version1 = fromIntegral (0x80010000 :: Word32)
import Thrift.Types
class Protocol a where
getTransport :: Transport t => a t -> t
readByte :: a -> IO LBS.ByteString
readVal :: a -> ThriftType -> IO ThriftVal
readMessage :: a -> ((Text, MessageType, Int32) -> IO b) -> IO b
writeMessageBegin :: Transport t => a t -> (Text, MessageType, Int32) -> IO ()
writeMessageEnd :: Transport t => a t -> IO ()
writeMessageEnd _ = return ()
readMessageBegin :: Transport t => a t -> IO (Text, MessageType, Int32)
readMessageEnd :: Transport t => a t -> IO ()
readMessageEnd _ = return ()
writeVal :: a -> ThriftVal -> IO ()
writeMessage :: a -> (Text, MessageType, Int32) -> IO () -> IO ()
serializeVal :: Transport t => a t -> ThriftVal -> ByteString
deserializeVal :: Transport t => a t -> ThriftType -> ByteString -> ThriftVal
writeVal :: Transport t => a t -> ThriftVal -> IO ()
writeVal p = tWrite (getTransport p) . serializeVal p
readVal :: Transport t => a t -> ThriftType -> IO ThriftVal
class Protocol a => StatelessProtocol a where
serializeVal :: a -> ThriftVal -> LBS.ByteString
deserializeVal :: a -> ThriftType -> LBS.ByteString -> ThriftVal
data ProtocolExnType
= PE_UNKNOWN
@ -105,10 +92,10 @@ getTypeOf v = case v of
TBinary{} -> T_BINARY
TDouble{} -> T_DOUBLE
runParser :: (Protocol p, Transport t, Show a) => p t -> Parser a -> IO a
runParser :: (Protocol p, Show a) => p -> Parser a -> IO a
runParser prot p = refill >>= getResult . parse p
where
refill = handle handleEOF $ toStrict <$> tReadAll (getTransport prot) 1
refill = handle handleEOF $ LBS.toStrict <$> readByte prot
getResult (Done _ a) = return a
getResult (Partial k) = refill >>= getResult . k
getResult f = throw $ ProtocolExn PE_INVALID_DATA (show f)

View file

@ -25,6 +25,8 @@
module Thrift.Protocol.Binary
( module Thrift.Protocol
, BinaryProtocol(..)
, versionMask
, version1
) where
import Control.Exception ( throw )
@ -35,6 +37,7 @@ import Data.Functor
import Data.Int
import Data.Monoid
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
import Data.Word
import Thrift.Protocol
import Thrift.Transport
@ -47,37 +50,55 @@ import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as Map
import qualified Data.Text.Lazy as LT
data BinaryProtocol a = BinaryProtocol a
versionMask :: Int32
versionMask = fromIntegral (0xffff0000 :: Word32)
version1 :: Int32
version1 = fromIntegral (0x80010000 :: Word32)
data BinaryProtocol a = Transport a => BinaryProtocol a
getTransport :: Transport t => BinaryProtocol t -> t
getTransport (BinaryProtocol t) = t
-- NOTE: Reading and Writing functions rely on Builders and Data.Binary to
-- encode and decode data. Data.Binary assumes that the binary values it is
-- encoding to and decoding from are in BIG ENDIAN format, and converts the
-- endianness as necessary to match the local machine.
instance Protocol BinaryProtocol where
getTransport (BinaryProtocol t) = t
instance Transport t => Protocol (BinaryProtocol t) where
readByte p = tReadAll (getTransport p) 1
-- flushTransport p = tFlush (getTransport p)
writeMessage p (n, t, s) f = do
tWrite (getTransport p) messageBegin
f
tFlush $ getTransport p
where
messageBegin = toLazyByteString $
buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
buildBinaryValue (TString $ encodeUtf8 n) <>
buildBinaryValue (TI32 s)
writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
buildBinaryValue (TString $ encodeUtf8 n) <>
buildBinaryValue (TI32 s)
readMessage p = (readMessageBegin p >>=)
where
readMessageBegin p = runParser p $ do
TI32 ver <- parseBinaryValue T_I32
if ver .&. versionMask /= version1
then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
else do
TString s <- parseBinaryValue T_STRING
TI32 sz <- parseBinaryValue T_I32
return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
readMessageBegin p = runParser p $ do
TI32 ver <- parseBinaryValue T_I32
if ver .&. versionMask /= version1
then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
else do
TString s <- parseBinaryValue T_STRING
TI32 sz <- parseBinaryValue T_I32
return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue
readVal p = runParser p . parseBinaryValue
instance Transport t => StatelessProtocol (BinaryProtocol t) where
serializeVal _ = toLazyByteString . buildBinaryValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of
Left s -> error s
Right val -> val
readVal p = runParser p . parseBinaryValue
-- | Writing Functions
buildBinaryValue :: ThriftVal -> Builder
buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP

View file

@ -25,10 +25,11 @@
module Thrift.Protocol.Compact
( module Thrift.Protocol
, CompactProtocol(..)
, parseVarint
, buildVarint
) where
import Control.Applicative
import Control.Exception ( throw )
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Lazy as LP
@ -40,7 +41,7 @@ import Data.Monoid
import Data.Word
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
import Thrift.Protocol hiding (versionMask)
import Thrift.Protocol
import Thrift.Transport
import Thrift.Types
@ -64,38 +65,47 @@ typeBits = 0x07 -- 0000 0111
typeShiftAmount :: Int
typeShiftAmount = 5
getTransport :: Transport t => CompactProtocol t -> t
getTransport (CompactProtocol t) = t
instance Protocol CompactProtocol where
getTransport (CompactProtocol t) = t
instance Transport t => Protocol (CompactProtocol t) where
readByte p = tReadAll (getTransport p) 1
writeMessage p (n, t, s) f = do
tWrite (getTransport p) messageBegin
f
tFlush $ getTransport p
where
messageBegin = toLazyByteString $
B.word8 protocolID <>
B.word8 ((version .&. versionMask) .|.
(((fromIntegral $ fromEnum t) `shiftL`
typeShiftAmount) .&. typeMask)) <>
buildVarint (i32ToZigZag s) <>
buildCompactValue (TString $ encodeUtf8 n)
writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
B.word8 protocolID <>
B.word8 ((version .&. versionMask) .|.
(((fromIntegral $ fromEnum t) `shiftL`
typeShiftAmount) .&. typeMask)) <>
buildVarint (i32ToZigZag s) <>
buildCompactValue (TString $ encodeUtf8 n)
readMessageBegin p = runParser p $ do
pid <- fromIntegral <$> P.anyWord8
when (pid /= protocolID) $ error "Bad Protocol ID"
w <- fromIntegral <$> P.anyWord8
let ver = w .&. versionMask
when (ver /= version) $ error "Bad Protocol version"
let typ = (w `shiftR` typeShiftAmount) .&. typeBits
seqId <- parseVarint zigZagToI32
TString name <- parseCompactValue T_STRING
return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
readMessage p f = readMessageBegin >>= f
where
readMessageBegin = runParser p $ do
pid <- fromIntegral <$> P.anyWord8
when (pid /= protocolID) $ error "Bad Protocol ID"
w <- fromIntegral <$> P.anyWord8
let ver = w .&. versionMask
when (ver /= version) $ error "Bad Protocol version"
let typ = (w `shiftR` typeShiftAmount) .&. typeBits
seqId <- parseVarint zigZagToI32
TString name <- parseCompactValue T_STRING
return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
writeVal p = tWrite (getTransport p) . toLazyByteString . buildCompactValue
readVal p ty = runParser p $ parseCompactValue ty
instance Transport t => StatelessProtocol (CompactProtocol t) where
serializeVal _ = toLazyByteString . buildCompactValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of
Left s -> error s
Right val -> val
readVal p ty = runParser p $ parseCompactValue ty
-- | Writing Functions
buildCompactValue :: ThriftVal -> Builder
buildCompactValue (TStruct fields) = buildCompactStruct fields
@ -283,7 +293,7 @@ typeOf v = case v of
TSet{} -> 0x0A
TMap{} -> 0x0B
TStruct{} -> 0x0C
typeFrom :: Word8 -> ThriftType
typeFrom w = case w of
0x01 -> T_BOOL

View file

@ -0,0 +1,141 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.
--
module Thrift.Protocol.Header
( module Thrift.Protocol
, HeaderProtocol(..)
, getProtocolType
, setProtocolType
, getHeaders
, getWriteHeaders
, setHeader
, setHeaders
, createHeaderProtocol
, createHeaderProtocol1
) where
import Thrift.Protocol
import Thrift.Protocol.Binary
import Thrift.Protocol.JSON
import Thrift.Protocol.Compact
import Thrift.Transport
import Thrift.Transport.Header
import Data.IORef
import qualified Data.Map as Map
data ProtocolWrap = forall a. (Protocol a) => ProtocolWrap(a)
instance Protocol ProtocolWrap where
readByte (ProtocolWrap p) = readByte p
readVal (ProtocolWrap p) = readVal p
readMessage (ProtocolWrap p) = readMessage p
writeVal (ProtocolWrap p) = writeVal p
writeMessage (ProtocolWrap p) = writeMessage p
data HeaderProtocol i o = (Transport i, Transport o) => HeaderProtocol {
trans :: HeaderTransport i o,
wrappedProto :: IORef ProtocolWrap
}
createProtocolWrap :: Transport t => ProtocolType -> t -> ProtocolWrap
createProtocolWrap typ t =
case typ of
TBinary -> ProtocolWrap $ BinaryProtocol t
TCompact -> ProtocolWrap $ CompactProtocol t
TJSON -> ProtocolWrap $ JSONProtocol t
createHeaderProtocol :: (Transport i, Transport o) => i -> o -> IO(HeaderProtocol i o)
createHeaderProtocol i o = do
t <- openHeaderTransport i o
pid <- readIORef $ protocolType t
proto <- newIORef $ createProtocolWrap pid t
return $ HeaderProtocol { trans = t, wrappedProto = proto }
createHeaderProtocol1 :: Transport t => t -> IO(HeaderProtocol t t)
createHeaderProtocol1 t = createHeaderProtocol t t
resetProtocol :: (Transport i, Transport o) => HeaderProtocol i o -> IO ()
resetProtocol p = do
pid <- readIORef $ protocolType $ trans p
writeIORef (wrappedProto p) $ createProtocolWrap pid $ trans p
getWrapped = readIORef . wrappedProto
setTransport :: (Transport i, Transport o) => HeaderProtocol i o -> HeaderTransport i o -> HeaderProtocol i o
setTransport p t = p { trans = t }
updateTransport :: (Transport i, Transport o) => HeaderProtocol i o -> (HeaderTransport i o -> HeaderTransport i o)-> HeaderProtocol i o
updateTransport p f = setTransport p (f $ trans p)
type Headers = Map.Map String String
-- TODO: we want to set headers without recreating client...
setHeader :: (Transport i, Transport o) => HeaderProtocol i o -> String -> String -> HeaderProtocol i o
setHeader p k v = updateTransport p $ \t -> t { writeHeaders = Map.insert k v $ writeHeaders t }
setHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers -> HeaderProtocol i o
setHeaders p h = updateTransport p $ \t -> t { writeHeaders = h }
-- TODO: make it public once we have first transform implementation for Haskell
setTransforms :: (Transport i, Transport o) => HeaderProtocol i o -> [TransformType] -> HeaderProtocol i o
setTransforms p trs = updateTransport p $ \t -> t { writeTransforms = trs }
setTransform :: (Transport i, Transport o) => HeaderProtocol i o -> TransformType -> HeaderProtocol i o
setTransform p tr = updateTransport p $ \t -> t { writeTransforms = tr:(writeTransforms t) }
getWriteHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers
getWriteHeaders = writeHeaders . trans
getHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> IO [(String, String)]
getHeaders = readIORef . headers . trans
getProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolType
getProtocolType p = readIORef $ protocolType $ trans p
setProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> ProtocolType -> IO ()
setProtocolType p typ = do
typ0 <- getProtocolType p
if typ == typ0
then return ()
else do
tSetProtocol (trans p) typ
resetProtocol p
instance (Transport i, Transport o) => Protocol (HeaderProtocol i o) where
readByte p = tReadAll (trans p) 1
readVal p tp = do
proto <- getWrapped p
readVal proto tp
readMessage p f = do
tResetProtocol (trans p)
resetProtocol p
proto <- getWrapped p
readMessage proto f
writeVal p v = do
proto <- getWrapped p
writeVal proto v
writeMessage p x f = do
proto <- getWrapped p
writeMessage proto x f

View file

@ -29,12 +29,12 @@ module Thrift.Protocol.JSON
) where
import Control.Applicative
import Control.Exception (bracket)
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Char8 as PC
import Data.Attoparsec.ByteString.Lazy as LP
import Data.ByteString.Base64.Lazy as B64C
import Data.ByteString.Base64 as B64
import Data.ByteString.Lazy.Builder as B
import Data.ByteString.Internal (c2w, w2c)
import Data.Functor
@ -58,38 +58,48 @@ import qualified Data.Text.Lazy as LT
-- encoded as a JSON 'ByteString'
data JSONProtocol t = JSONProtocol t
-- ^ Construct a 'JSONProtocol' with a 'Transport'
getTransport :: Transport t => JSONProtocol t -> t
getTransport (JSONProtocol t) = t
instance Protocol JSONProtocol where
getTransport (JSONProtocol t) = t
instance Transport t => Protocol (JSONProtocol t) where
readByte p = tReadAll (getTransport p) 1
writeMessageBegin (JSONProtocol t) (s, ty, sq) = tWrite t $ toLazyByteString $
B.char8 '[' <> buildShowable (1 :: Int32) <>
B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
B.char8 ',' <> buildShowable (fromEnum ty) <>
B.char8 ',' <> buildShowable sq <>
B.char8 ','
writeMessageEnd (JSONProtocol t) = tWrite t "]"
readMessageBegin p = runParser p $ skipSpace *> do
_ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
bs <- lexeme (PC.char8 ',') *> lexeme escapedString
case decodeUtf8' bs of
Left _ -> fail "readMessage: invalid text encoding"
Right str -> do
ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
_ <- PC.char8 ','
return (str, ty, seqNum)
readMessageEnd p = void $ runParser p (PC.char8 ']')
writeMessage (JSONProtocol t) (s, ty, sq) = bracket readMessageBegin readMessageEnd . const
where
readMessageBegin = tWrite t $ toLazyByteString $
B.char8 '[' <> buildShowable (1 :: Int32) <>
B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
B.char8 ',' <> buildShowable (fromEnum ty) <>
B.char8 ',' <> buildShowable sq <>
B.char8 ','
readMessageEnd _ = do
tWrite t "]"
tFlush t
readMessage p = bracket readMessageBegin readMessageEnd
where
readMessageBegin = runParser p $ skipSpace *> do
_ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
bs <- lexeme (PC.char8 ',') *> lexeme escapedString
case decodeUtf8' bs of
Left _ -> fail "readMessage: invalid text encoding"
Right str -> do
ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
_ <- PC.char8 ','
return (str, ty, seqNum)
readMessageEnd _ = void $ runParser p (PC.char8 ']')
writeVal p = tWrite (getTransport p) . toLazyByteString . buildJSONValue
readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
instance Transport t => StatelessProtocol (JSONProtocol t) where
serializeVal _ = toLazyByteString . buildJSONValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseJSONValue ty) bs of
Left s -> error s
Right val -> val
readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
-- Writing Functions
buildJSONValue :: ThriftVal -> Builder

View file

@ -38,10 +38,10 @@ import Thrift.Protocol.Binary
-- | A threaded sever that is capable of using any Transport or Protocol
-- instances.
runThreadedServer :: (Transport t, Protocol i, Protocol o)
=> (Socket -> IO (i t, o t))
runThreadedServer :: (Protocol i, Protocol o)
=> (Socket -> IO (i, o))
-> h
-> (h -> (i t, o t) -> IO Bool)
-> (h -> (i, o) -> IO Bool)
-> PortID
-> IO a
runThreadedServer accepter hand proc_ port = do

View file

@ -44,7 +44,13 @@ import Data.Monoid
instance Transport Handle where
tIsOpen = hIsOpen
tClose = hClose
tRead h n = LBS.hGet h n `Control.Exception.catch` handleEOF mempty
tRead h n = read `Control.Exception.catch` handleEOF mempty
where
read = do
hLookAhead h
LBS.hGetNonBlocking h n
tReadAll _ 0 = return mempty
tReadAll h n = LBS.hGet h n `Control.Exception.catch` throwTransportExn
tPeek h = (Just . c2w <$> hLookAhead h) `Control.Exception.catch` handleEOF Nothing
tWrite = LBS.hPut
tFlush = hFlush
@ -61,8 +67,12 @@ instance HandleSource FilePath where
instance HandleSource (HostName, PortID) where
hOpen = uncurry connectTo
throwTransportExn :: IOError -> IO a
throwTransportExn e = if isEOFError e
then throw $ TransportExn "Cannot read. Remote side has closed." TE_UNKNOWN
else throw $ TransportExn "Handle tReadAll: Could not read" TE_UNKNOWN
handleEOF :: a -> IOError -> IO a
handleEOF a e = if isEOFError e
then return a
else throw $ TransportExn "TChannelTransport: Could not read" TE_UNKNOWN
else throw $ TransportExn "Handle: Could not read" TE_UNKNOWN

View file

@ -0,0 +1,354 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.
--
module Thrift.Transport.Header
( module Thrift.Transport
, HeaderTransport(..)
, openHeaderTransport
, ProtocolType(..)
, TransformType(..)
, ClientType(..)
, tResetProtocol
, tSetProtocol
) where
import Thrift.Transport
import Thrift.Protocol.Compact
import Control.Applicative
import Control.Exception ( throw )
import Control.Monad
import Data.Bits
import Data.IORef
import Data.Int
import Data.Monoid
import Data.Word
import qualified Data.Attoparsec.ByteString as P
import qualified Data.Binary as Binary
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Lazy.Builder as B
import qualified Data.Map as Map
data ProtocolType = TBinary | TCompact | TJSON deriving (Enum, Eq)
data ClientType = HeaderClient | Framed | Unframed deriving (Enum, Eq)
infoIdKeyValue = 1
type Headers = Map.Map String String
data TransformType = ZlibTransform deriving (Enum, Eq)
fromTransportType :: TransformType -> Int16
fromTransportType ZlibTransform = 1
toTransportType :: Int16 -> TransformType
toTransportType 1 = ZlibTransform
toTransportType _ = throw $ TransportExn "HeaderTransport: Unknown transform ID" TE_UNKNOWN
data HeaderTransport i o = (Transport i, Transport o) => HeaderTransport
{ readBuffer :: IORef LBS.ByteString
, writeBuffer :: IORef B.Builder
, inTrans :: i
, outTrans :: o
, clientType :: IORef ClientType
, protocolType :: IORef ProtocolType
, headers :: IORef [(String, String)]
, writeHeaders :: Headers
, transforms :: IORef [TransformType]
, writeTransforms :: [TransformType]
}
openHeaderTransport :: (Transport i, Transport o) => i -> o -> IO (HeaderTransport i o)
openHeaderTransport i o = do
pid <- newIORef TCompact
rBuf <- newIORef LBS.empty
wBuf <- newIORef mempty
cType <- newIORef HeaderClient
h <- newIORef []
trans <- newIORef []
return HeaderTransport
{ readBuffer = rBuf
, writeBuffer = wBuf
, inTrans = i
, outTrans = o
, clientType = cType
, protocolType = pid
, headers = h
, writeHeaders = Map.empty
, transforms = trans
, writeTransforms = []
}
isFramed t = (/= Unframed) <$> readIORef (clientType t)
readFrame :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
readFrame t = do
let input = inTrans t
let rBuf = readBuffer t
let cType = clientType t
lsz <- tRead input 4
let sz = LBS.toStrict lsz
case P.parseOnly P.endOfInput sz of
Right _ -> do return False
Left _ -> do
case parseBinaryMagic sz of
Right _ -> do
writeIORef rBuf $ lsz
writeIORef cType Unframed
writeIORef (protocolType t) TBinary
return True
Left _ -> do
case parseCompactMagic sz of
Right _ -> do
writeIORef rBuf $ lsz
writeIORef cType Unframed
writeIORef (protocolType t) TCompact
return True
Left _ -> do
let len = Binary.decode lsz :: Int32
lbuf <- tReadAll input $ fromIntegral len
let buf = LBS.toStrict lbuf
case parseBinaryMagic buf of
Right _ -> do
writeIORef cType Framed
writeIORef (protocolType t) TBinary
writeIORef rBuf lbuf
return True
Left _ -> do
case parseCompactMagic buf of
Right _ -> do
writeIORef cType Framed
writeIORef (protocolType t) TCompact
writeIORef rBuf lbuf
return True
Left _ -> do
case parseHeaderMagic buf of
Right flags -> do
let (flags, seqNum, header, body) = extractHeader buf
writeIORef cType HeaderClient
handleHeader t header
payload <- untransform t body
writeIORef rBuf $ LBS.fromStrict $ payload
return True
Left _ ->
throw $ TransportExn "HeaderTransport: unkonwn client type" TE_UNKNOWN
parseBinaryMagic = P.parseOnly $ P.word8 0x80 *> P.word8 0x01 *> P.word8 0x00 *> P.anyWord8
parseCompactMagic = P.parseOnly $ P.word8 0x82 *> P.satisfy (\b -> b .&. 0x1f == 0x01)
parseHeaderMagic = P.parseOnly $ P.word8 0x0f *> P.word8 0xff *> (P.count 2 P.anyWord8)
parseI32 :: P.Parser Int32
parseI32 = Binary.decode . LBS.fromStrict <$> P.take 4
parseI16 :: P.Parser Int16
parseI16 = Binary.decode . LBS.fromStrict <$> P.take 2
extractHeader :: BS.ByteString -> (Int16, Int32, BS.ByteString, BS.ByteString)
extractHeader bs =
case P.parse extractHeader_ bs of
P.Done remain (flags, seqNum, header) -> (flags, seqNum, header, remain)
_ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
where
extractHeader_ = do
magic <- P.word8 0x0f *> P.word8 0xff
flags <- parseI16
seqNum <- parseI32
(headerSize :: Int) <- (* 4) . fromIntegral <$> parseI16
header <- P.take headerSize
return (flags, seqNum, header)
handleHeader t header =
case P.parseOnly parseHeader header of
Right (pType, trans, info) -> do
writeIORef (protocolType t) pType
writeIORef (transforms t) trans
writeIORef (headers t) info
_ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
iw16 :: Int16 -> Word16
iw16 = fromIntegral
iw32 :: Int32 -> Word32
iw32 = fromIntegral
wi16 :: Word16 -> Int16
wi16 = fromIntegral
wi32 :: Word32 -> Int32
wi32 = fromIntegral
parseHeader :: P.Parser (ProtocolType, [TransformType], [(String, String)])
parseHeader = do
protocolType <- toProtocolType <$> parseVarint wi16
numTrans <- fromIntegral <$> parseVarint wi16
trans <- replicateM numTrans parseTransform
info <- parseInfo
return (protocolType, trans, info)
toProtocolType :: Int16 -> ProtocolType
toProtocolType 0 = TBinary
toProtocolType 1 = TJSON
toProtocolType 2 = TCompact
fromProtocolType :: ProtocolType -> Int16
fromProtocolType TBinary = 0
fromProtocolType TJSON = 1
fromProtocolType TCompact = 2
parseTransform :: P.Parser TransformType
parseTransform = toTransportType <$> parseVarint wi16
parseInfo :: P.Parser [(String, String)]
parseInfo = do
n <- P.eitherP P.endOfInput (parseVarint wi32)
case n of
Left _ -> return []
Right n0 ->
replicateM (fromIntegral n0) $ do
klen <- parseVarint wi16
k <- P.take $ fromIntegral klen
vlen <- parseVarint wi16
v <- P.take $ fromIntegral vlen
return (C.unpack k, C.unpack v)
parseString :: P.Parser BS.ByteString
parseString = parseVarint wi32 >>= (P.take . fromIntegral)
buildHeader :: HeaderTransport i o -> IO B.Builder
buildHeader t = do
pType <- readIORef $ protocolType t
let pId = buildVarint $ iw16 $ fromProtocolType pType
let headerContent = pId <> (buildTransforms t) <> (buildInfo t)
let len = fromIntegral $ LBS.length $ B.toLazyByteString headerContent
-- TODO: length limit check
let padding = mconcat $ replicate (mod len 4) $ B.word8 0
let codedLen = B.int16BE (fromIntegral $ (quot (len - 1) 4) + 1)
let flags = 0
let seqNum = 0
return $ B.int16BE 0x0fff <> B.int16BE flags <> B.int32BE seqNum <> codedLen <> headerContent <> padding
buildTransforms :: HeaderTransport i o -> B.Builder
-- TODO: check length limit
buildTransforms t =
let trans = writeTransforms t in
(buildVarint $ iw16 $ fromIntegral $ length trans) <>
(mconcat $ map (buildVarint . iw16 . fromTransportType) trans)
buildInfo :: HeaderTransport i o -> B.Builder
buildInfo t =
let h = Map.assocs $ writeHeaders t in
-- TODO: check length limit
case length h of
0 -> mempty
len -> (buildVarint $ iw16 $ fromIntegral $ len) <> (mconcat $ map buildInfoEntry h)
where
buildInfoEntry (k, v) = buildVarStr k <> buildVarStr v
-- TODO: check length limit
buildVarStr s = (buildVarint $ iw16 $ fromIntegral $ length s) <> B.string8 s
tResetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
tResetProtocol t = do
rBuf <- readIORef $ readBuffer t
writeIORef (clientType t) HeaderClient
readFrame t
tSetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> ProtocolType -> IO ()
tSetProtocol t = writeIORef (protocolType t)
transform :: HeaderTransport i o -> LBS.ByteString -> LBS.ByteString
transform t bs =
foldr applyTransform bs $ writeTransforms t
where
-- applyTransform bs ZlibTransform =
-- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN
applyTransform bs _ =
throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
untransform :: HeaderTransport i o -> BS.ByteString -> IO BS.ByteString
untransform t bs = do
trans <- readIORef $ transforms t
return $ foldl unapplyTransform bs trans
where
-- unapplyTransform bs ZlibTransform =
-- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN
unapplyTransform bs _ =
throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
instance (Transport i, Transport o) => Transport (HeaderTransport i o) where
tIsOpen t = do
tIsOpen (inTrans t)
tIsOpen (outTrans t)
tClose t = do
tClose(outTrans t)
tClose(inTrans t)
tRead t len = do
rBuf <- readIORef $ readBuffer t
if not $ LBS.null rBuf
then do
let (consumed, remain) = LBS.splitAt (fromIntegral len) rBuf
writeIORef (readBuffer t) remain
return consumed
else do
framed <- isFramed t
if not framed
then tRead (inTrans t) len
else do
ok <- readFrame t
if ok
then tRead t len
else return LBS.empty
tPeek t = do
rBuf <- readIORef (readBuffer t)
if not $ LBS.null rBuf
then return $ Just $ LBS.head rBuf
else do
framed <- isFramed t
if not framed
then tPeek (inTrans t)
else do
ok <- readFrame t
if ok
then tPeek t
else return Nothing
tWrite t buf = do
let wBuf = writeBuffer t
framed <- isFramed t
if framed
then modifyIORef wBuf (<> B.lazyByteString buf)
else
-- TODO: what should we do when switched to unframed in the middle ?
tWrite(outTrans t) buf
tFlush t = do
cType <- readIORef $ clientType t
case cType of
Unframed -> tFlush $ outTrans t
Framed -> flushBuffer t id mempty
HeaderClient -> buildHeader t >>= flushBuffer t (transform t)
where
flushBuffer t f header = do
wBuf <- readIORef $ writeBuffer t
writeIORef (writeBuffer t) mempty
let payload = B.toLazyByteString (header <> wBuf)
tWrite (outTrans t) $ Binary.encode (fromIntegral $ LBS.length payload :: Int32)
tWrite (outTrans t) $ f payload
tFlush (outTrans t)