--
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"), you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
--   http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.
--

require 'TProtocol'
require 'libluabpack'
require 'libluabitwise'

-- JSON protocol object derived from TProtocolBase.
-- The fields below are the initial (prototype-level) values; resetContext()
-- replaces the context fields with fresh tables on the instance.
TJSONProtocol = __TObject.new(TProtocolBase, {
  __type = 'TJSONProtocol',
  -- Version number written/checked as the first element of every message.
  THRIFT_JSON_PROTOCOL_VERSION = 1,
  -- Stack of saved contexts (see contextPush / contextPop).
  jsonContext = {},
  -- Current context: `first` = no element seen yet, `colon` = next object
  -- separator is ':' rather than ',', `ttype` = 1 for object / 2 for array,
  -- `null` = root context in which no separators are written or read.
  jsonContextVal = {first = true, colon = true, ttype = 2, null = true},
  jsonContextIndex = 1,
  -- One-byte lookahead buffer; "" means nothing is pending.
  hasReadByte = ""
})

-- Maps each TType constant to the short type tag written into the JSON
-- stream (inverse of StringToTType).
TTypeToString = {
  [TType.BOOL]   = "tf",
  [TType.BYTE]   = "i8",
  [TType.I16]    = "i16",
  [TType.I32]    = "i32",
  [TType.I64]    = "i64",
  [TType.DOUBLE] = "dbl",
  [TType.STRING] = "str",
  [TType.STRUCT] = "rec",
  [TType.LIST]   = "lst",
  [TType.SET]    = "set",
  [TType.MAP]    = "map",
}

-- Maps the short type tag found in the JSON stream back to its TType
-- constant (inverse of TTypeToString). Unknown tags look up as nil.
StringToTType = {
  tf  = TType.BOOL,
  i8  = TType.BYTE,
  i16 = TType.I16,
  i32 = TType.I32,
  i64 = TType.I64,
  dbl = TType.DOUBLE,
  str = TType.STRING,
  rec = TType.STRUCT,
  map = TType.MAP,
  set = TType.SET,
  lst = TType.LIST
}

-- Literal characters and strings used when writing and reading the JSON
-- stream.
JSONNode = {
  ObjectBegin = '{',
  ObjectEnd = '}',
  ArrayBegin = '[',
  ArrayEnd = ']',
  PairSeparator = ':',
  ElemSeparator = ',',
  Backslash = '\\',
  StringDelimiter = '"',
  ZeroChar = '0',
  -- Character following a backslash that introduces a \u00XX escape.
  EscapeChar = 'u',
  -- Textual forms used for non-finite doubles.
  Nan = 'NaN',
  Infinity = 'Infinity',
  NegativeInfinity = '-Infinity',
  -- Characters accepted after a backslash; position i corresponds to
  -- EscapeCharVals[i].
  EscapeChars = "\"\\bfnrt",
  -- Prefix written before the two hex digits of an escaped byte.
  EscapePrefix = "\\u00"
}

-- Decoded values of the single-character escapes, in the same order as the
-- characters of JSONNode.EscapeChars (readJSONString indexes this table with
-- the position found in that string).
EscapeCharVals = {
  '"', '\\', '\b', '\f', '\n', '\r', '\t'
}

-- Escape table for bytes below 0x30, indexed by (byte value + 1).
-- As consumed by writeJSONChar:
--   0      -> write the byte as a \u00XX escape
--   1      -> write the byte unchanged
--   other  -> write a backslash followed by the character with this code
--             (98 'b', 116 't', 110 'n', 102 'f', 114 'r', 34 '"')
JSONCharTable = {
  --0   1   2   3   4   5   6   7   8   9   A   B   C   D   E   F
    0,  0,  0,  0,  0,  0,  0,  0, 98,116,110,  0,102,114,  0,  0,
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
    1,  1,34,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
}

-- Base64 alphabet; the character for 6-bit value v is at position v + 1.
local b='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'

-- Encodes the byte string `data` as padded base64 text.
-- Input is processed in groups of up to three bytes; each group is packed
-- into a 24-bit number and split into four 6-bit alphabet indices. A short
-- final group only emits the characters that carry real data, and '=' padding
-- fills the output up to a multiple of four characters.
function base64_encode(data)
    local pieces = {}
    local total = #data
    for pos = 1, total, 3 do
        local b1, b2, b3 = data:byte(pos, pos + 2)
        local packed = b1 * 65536 + (b2 or 0) * 256 + (b3 or 0)
        local i1 = math.floor(packed / 262144)
        local i2 = math.floor(packed / 4096) % 64
        local i3 = math.floor(packed / 64) % 64
        local i4 = packed % 64
        pieces[#pieces + 1] = b:sub(i1 + 1, i1 + 1)
        pieces[#pieces + 1] = b:sub(i2 + 1, i2 + 1)
        if b2 then
            pieces[#pieces + 1] = b:sub(i3 + 1, i3 + 1)
        end
        if b3 then
            pieces[#pieces + 1] = b:sub(i4 + 1, i4 + 1)
        end
    end
    pieces[#pieces + 1] = ({ '', '==', '=' })[total % 3 + 1]
    return table.concat(pieces)
end

-- Decodes base64 text into a byte string.
-- Characters outside the alphabet (and other than '=') are stripped first.
-- Each remaining non-'=' character contributes 6 bits to a running
-- accumulator; every time 8 bits are available one byte is emitted. Bits left
-- over at the end (fewer than 8) are discarded.
function base64_decode(data)
    data = string.gsub(data, '[^'..b..'=]', '')
    local bytes = {}
    local acc, nbits = 0, 0
    for ch in data:gmatch('.') do
        if ch ~= '=' then
            acc = acc * 64 + (b:find(ch, 1, true) - 1)
            nbits = nbits + 6
            if nbits >= 8 then
                nbits = nbits - 8
                local scale = 2 ^ nbits
                bytes[#bytes + 1] = string.char(math.floor(acc / scale))
                -- keep only the bits that have not been emitted yet
                acc = acc % scale
            end
        end
    end
    return table.concat(bytes)
end

-- Discards any saved contexts and reinstalls the root context (array-typed,
-- `null = true`, so no separators are written or read at the top level).
function TJSONProtocol:resetContext()
  self.jsonContext = {}
  self.jsonContextVal = {first = true, colon = true, ttype = 2, null = true}
  self.jsonContextIndex = 1
end

-- Saves the current context on the stack and makes `context` current.
-- `context` is a table with the fields first / colon / ttype / null.
function TJSONProtocol:contextPush(context)
  self.jsonContextIndex = self.jsonContextIndex + 1
  self.jsonContext[self.jsonContextIndex] = self.jsonContextVal
  self.jsonContextVal = context
end

-- Restores the most recently pushed context as the current one.
-- The stack slot itself is not cleared; only the index moves down.
function TJSONProtocol:contextPop()
  self.jsonContextVal = self.jsonContext[self.jsonContextIndex]
  self.jsonContextIndex = self.jsonContextIndex - 1
end

-- Tells whether a number at the current position must be wrapped in quotes.
-- Outside an object context (ttype ~= 1) the answer is always false; inside
-- one it is the context's `colon` flag.
function TJSONProtocol:escapeNum()
  local ctx = self.jsonContextVal
  if ctx.ttype ~= 1 then
    return false
  end
  return ctx.colon
end

-- Writes whatever separator must precede the next value in the current
-- context, and updates the context's bookkeeping flags.
--   * root (`null`) context: nothing is written
--   * first value of a container: nothing is written, `first` is cleared
--   * array context: ','
--   * object context: ':' and ',' alternate, tracked by `colon`
function TJSONProtocol:writeElemSeparator()
  local ctx = self.jsonContextVal
  if ctx.null then
    return
  end
  if ctx.first then
    ctx.first = false
    return
  end
  if ctx.ttype ~= 1 then
    self.trans:write(JSONNode.ElemSeparator)
    return
  end
  if ctx.colon then
    self.trans:write(JSONNode.PairSeparator)
  else
    self.trans:write(JSONNode.ElemSeparator)
  end
  -- flip only after the write, mirroring the write-then-update order
  ctx.colon = not ctx.colon
end

-- Returns the character code of the lowercase hex digit for the low four
-- bits of `val` (48..57 for 0-9, 97..102 for a-f).
function TJSONProtocol:hexChar(val)
  local nibble = libluabitwise.band(val, 0x0f)
  if nibble >= 10 then
    return nibble + 87
  end
  return nibble + 48
end

-- Writes the byte value `ch` (a number) as a "\u00XX" escape: the fixed
-- prefix followed by the high and low hex digits.
function TJSONProtocol:writeJSONEscapeChar(ch)
  self.trans:write(JSONNode.EscapePrefix)
  -- hexChar is defined as a method on TJSONProtocol, so it has to be called
  -- through self; a bare `hexChar(...)` resolves to an undefined global.
  local outCh = self:hexChar(libluabitwise.shiftr(ch, 4))
  local buff = libluabpack.bpack('c', outCh)
  self.trans:write(buff)
  -- hexChar masks to the low nibble itself, so `ch` can be passed as is.
  outCh = self:hexChar(ch)
  buff = libluabpack.bpack('c', outCh)
  self.trans:write(buff)
end

-- Writes one character of a JSON string body, escaping it when required.
-- `byte` is a one-character string.
--   * bytes >= 0x30: written unchanged, except the backslash which is doubled
--   * bytes <  0x30: looked up in JSONCharTable (1 = unchanged, > 1 = short
--     backslash escape using that character code, 0 = \u00XX escape)
function TJSONProtocol:writeJSONChar(byte)
  -- `local` keeps the byte code from leaking into the global environment.
  local ch = string.byte(byte)
  if ch >= 0x30 then
    -- Compare the character itself: JSONNode.Backslash is a string, so
    -- comparing it with the numeric code would never match.
    if byte == JSONNode.Backslash then
      self.trans:write(JSONNode.Backslash)
      self.trans:write(JSONNode.Backslash)
    else
      self.trans:write(byte)
    end
  else
    local outCh = JSONCharTable[ch+1]
    if outCh == 1 then
      self.trans:write(byte)
    elseif outCh > 1 then
      self.trans:write(JSONNode.Backslash)
      local buff = libluabpack.bpack('c', outCh)
      self.trans:write(buff)
    else
      self:writeJSONEscapeChar(ch)
    end
  end
end

-- Writes `str` as a quoted JSON string, preceded by any separator the
-- current context requires. Each character goes through writeJSONChar,
-- which applies the escaping.
function TJSONProtocol:writeJSONString(str)
  self:writeElemSeparator()
  self.trans:write(JSONNode.StringDelimiter)
  for pos = 1, string.len(str) do
    self:writeJSONChar(string.sub(str, pos, pos))
  end
  self.trans:write(JSONNode.StringDelimiter)
end

-- Writes the byte string `str` as a quoted, base64-encoded JSON string,
-- preceded by any separator the current context requires.
function TJSONProtocol:writeJSONBase64(str)
  self:writeElemSeparator()
  self.trans:write(JSONNode.StringDelimiter)
  local length = string.len(str)
  local offset = 1
  while length >= 3 do
    -- Encode exactly 3 bytes at a time. string.sub is inclusive on both
    -- ends, so the group is [offset, offset + 2]; a 4-byte slice would make
    -- base64_encode emit padding in the middle of the output.
    local bytes = base64_encode(string.sub(str, offset, offset + 2))
    self.trans:write(bytes)
    length = length - 3
    offset = offset + 3
  end
  if length > 0 then
    -- Trailing 1 or 2 bytes; base64_encode adds the '=' padding.
    local bytes = base64_encode(string.sub(str, offset, offset + length - 1))
    self.trans:write(bytes)
  end
  self.trans:write(JSONNode.StringDelimiter)
end

-- Writes `num` as a JSON integer, preceded by any required separator.
-- Only the leading optionally-signed run of digits of the textual form is
-- written (so "1.0" is emitted as 1). The number is wrapped in quotes when
-- escapeNum() says so.
function TJSONProtocol:writeJSONInteger(num)
  self:writeElemSeparator()
  local quoted = self:escapeNum()
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
  local text = "" .. num
  local first, last = string.find(text, "^[+-]?%d+")
  self.trans:write(string.sub(text, first, last))
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
end

-- Writes a double, preceded by any required separator.
-- `dub` is converted to text by concatenation (writeDouble passes an already
-- formatted string). Non-finite values are recognised from the first
-- characters of that text and replaced by the JSONNode spellings, which are
-- always quoted; ordinary numbers are quoted only when escapeNum() says so.
function TJSONProtocol:writeJSONDouble(dub)
  self:writeElemSeparator()
  local val = "" .. dub
  local first = string.sub(val, 1, 1)
  local second = string.sub(val, 2, 2)
  local special = true
  if first == 'N' or first == 'n' then
    val = JSONNode.Nan
  elseif first == 'I' or first == 'i' then
    val = JSONNode.Infinity
  elseif first == '-' and (second == 'I' or second == 'i') then
    val = JSONNode.NegativeInfinity
  elseif first == '-' and (second == 'N' or second == 'n') then
    -- Some C libraries format NaN with a sign ("-nan"); without this branch
    -- it would be written as an unquoted, non-JSON token.
    val = JSONNode.Nan
  else
    special = false
  end

  local quoted = special or self:escapeNum()
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
  self.trans:write(val)
  if quoted then
    self.trans:write(JSONNode.StringDelimiter)
  end
end

-- Writes any required separator and '{', then enters a new object context
-- (ttype = 1).
function TJSONProtocol:writeJSONObjectBegin()
  self:writeElemSeparator()
  self.trans:write(JSONNode.ObjectBegin)
  self:contextPush({first = true, colon = true, ttype = 1, null = false})
end

-- Leaves the current context and writes '}'.
function TJSONProtocol:writeJSONObjectEnd()
  self:contextPop()
  self.trans:write(JSONNode.ObjectEnd)
end

-- Writes any required separator and '[', then enters a new array context
-- (ttype = 2).
function TJSONProtocol:writeJSONArrayBegin()
  self:writeElemSeparator()
  self.trans:write(JSONNode.ArrayBegin)
  self:contextPush({first = true, colon = true, ttype = 2, null = false})
end

-- Leaves the current context and writes ']'.
function TJSONProtocol:writeJSONArrayEnd()
  self:contextPop()
  self.trans:write(JSONNode.ArrayEnd)
end

-- Starts a message: resets the context, opens an array and writes the
-- protocol version, the message name, the message type and the sequence id.
-- The array is closed by writeMessageEnd.
function TJSONProtocol:writeMessageBegin(name, ttype, seqid)
  self:resetContext()
  self:writeJSONArrayBegin()
  self:writeJSONInteger(TJSONProtocol.THRIFT_JSON_PROTOCOL_VERSION)
  self:writeJSONString(name)
  self:writeJSONInteger(ttype)
  self:writeJSONInteger(seqid)
end

-- Closes the array opened by writeMessageBegin.
function TJSONProtocol:writeMessageEnd()
  self:writeJSONArrayEnd()
end

-- Starts a struct as a JSON object. `name` is not written.
function TJSONProtocol:writeStructBegin(name)
  self:writeJSONObjectBegin()
end

-- Closes the JSON object opened by writeStructBegin.
function TJSONProtocol:writeStructEnd()
  self:writeJSONObjectEnd()
end

-- Starts a field: writes the field id, opens an object and writes the type
-- tag for `ttype`; the field value is written next by the caller. `name` is
-- not written.
function TJSONProtocol:writeFieldBegin(name, ttype, id)
  self:writeJSONInteger(id)
  self:writeJSONObjectBegin()
  self:writeJSONString(TTypeToString[ttype])
end

-- Closes the object opened by writeFieldBegin.
function TJSONProtocol:writeFieldEnd()
  self:writeJSONObjectEnd()
end

-- Intentionally a no-op: nothing is written to mark the end of the fields.
function TJSONProtocol:writeFieldStop()
end

-- Starts a map: opens an array, writes the key type tag, the value type tag
-- and the size, then opens the object that will hold the entries.
function TJSONProtocol:writeMapBegin(ktype, vtype, size)
  self:writeJSONArrayBegin()
  self:writeJSONString(TTypeToString[ktype])
  self:writeJSONString(TTypeToString[vtype])
  self:writeJSONInteger(size)
  return self:writeJSONObjectBegin()
end

-- Closes the entries object and then the enclosing array opened by
-- writeMapBegin.
function TJSONProtocol:writeMapEnd()
  self:writeJSONObjectEnd()
  self:writeJSONArrayEnd()
end

-- Starts a list: opens an array and writes the element type tag and size;
-- the elements follow in the same array.
function TJSONProtocol:writeListBegin(etype, size)
  self:writeJSONArrayBegin()
  self:writeJSONString(TTypeToString[etype])
  self:writeJSONInteger(size)
end

-- Closes the array opened by writeListBegin.
function TJSONProtocol:writeListEnd()
  self:writeJSONArrayEnd()
end

-- Starts a set; the encoding is the same as for a list (array with element
-- type tag and size up front).
function TJSONProtocol:writeSetBegin(etype, size)
  self:writeJSONArrayBegin()
  self:writeJSONString(TTypeToString[etype])
  self:writeJSONInteger(size)
end

-- Closes the array opened by writeSetBegin.
function TJSONProtocol:writeSetEnd()
  self:writeJSONArrayEnd()
end

-- Writes a boolean as the integer 1 (truthy) or 0 (false / nil).
function TJSONProtocol:writeBool(bool)
  local flag = 0
  if bool then
    flag = 1
  end
  self:writeJSONInteger(flag)
end

-- Writes a byte value as a JSON integer.
-- The value is first round-tripped through bpack/bunpack with format 'c';
-- presumably this coerces it to the 8-bit range - confirm against libluabpack.
function TJSONProtocol:writeByte(byte)
  local buff = libluabpack.bpack('c', byte)
  local val = libluabpack.bunpack('c', buff)
  self:writeJSONInteger(val)
end

-- Writes a 16-bit integer as a JSON integer.
-- The value is first round-tripped through bpack/bunpack with format 's';
-- presumably this coerces it to the 16-bit range - confirm against
-- libluabpack.
function TJSONProtocol:writeI16(i16)
  local buff = libluabpack.bpack('s', i16)
  local val = libluabpack.bunpack('s', buff)
  self:writeJSONInteger(val)
end

-- Writes a 32-bit integer as a JSON integer.
-- The value is first round-tripped through bpack/bunpack with format 'i';
-- presumably this coerces it to the 32-bit range - confirm against
-- libluabpack.
function TJSONProtocol:writeI32(i32)
  local buff = libluabpack.bpack('i', i32)
  local val = libluabpack.bunpack('i', buff)
  self:writeJSONInteger(val)
end

-- Writes a 64-bit integer as a JSON integer.
-- The value is round-tripped through bpack/bunpack with format 'l' and then
-- converted with tostring() before being handed to writeJSONInteger, which
-- works on the textual form.
-- NOTE(review): the unpacked value is presumably a long-number object whose
-- tostring() yields decimal digits - confirm against libluabpack.
function TJSONProtocol:writeI64(i64)
  local buff = libluabpack.bpack('l', i64)
  local val = libluabpack.bunpack('l', buff)
  self:writeJSONInteger(tostring(val))
end

-- Writes a double. The value is formatted with 16 digits after the decimal
-- point and the resulting string is passed to writeJSONDouble, which detects
-- non-finite values from that text.
function TJSONProtocol:writeDouble(dub)
  self:writeJSONDouble(string.format("%.16f", dub))
end

-- Writes `str` as an escaped, quoted JSON string.
function TJSONProtocol:writeString(str)
  self:writeJSONString(str)
end

-- Writes the byte string `str` as a quoted base64 string.
function TJSONProtocol:writeBinary(str)
  -- Binary data is base64-encoded rather than written as escaped text.
  self:writeJSONBase64(str)
end

-- Consumes one character and checks that it equals `ch`.
-- A pending lookahead byte (hasReadByte) is used and cleared if present;
-- otherwise one byte is read from the transport. A mismatch is reported
-- through terror with a TProtocolException.
function TJSONProtocol:readJSONSyntaxChar(ch)
  local actual = self.hasReadByte
  if actual == "" then
    actual = self.trans:readAll(1)
  else
    self.hasReadByte = ""
  end
  if actual == ch then
    return
  end
  terror(TProtocolException:new{message = "Expected ".. ch .. ", got " .. actual})
end

-- Consumes whatever separator must precede the next value in the current
-- context, and updates the context's bookkeeping flags.
--   * root (`null`) context: nothing is read
--   * first value of a container: nothing is read, `first` is cleared
--   * array context: ','
--   * object context: ':' and ',' alternate, tracked by `colon`
function TJSONProtocol:readElemSeparator()
  local ctx = self.jsonContextVal
  if ctx.null then
    return
  end
  if ctx.first then
    ctx.first = false
    return
  end
  if ctx.ttype ~= 1 then
    self:readJSONSyntaxChar(JSONNode.ElemSeparator)
    return
  end
  if ctx.colon then
    self:readJSONSyntaxChar(JSONNode.PairSeparator)
  else
    self:readJSONSyntaxChar(JSONNode.ElemSeparator)
  end
  -- flip only after the separator was accepted
  ctx.colon = not ctx.colon
end

-- Converts a single hex digit character to its numeric value (0..15).
-- Digits 0-9 and letters a-f are accepted in either case, since JSON allows
-- both cases in \u escapes. Any other character is reported through terror
-- with a TProtocolException.
function TJSONProtocol:hexVal(ch)
  local val = string.byte(ch)
  if val >= 48 and val <= 57 then        -- '0'..'9'
    return val - 48
  elseif val >= 97 and val <= 102 then   -- 'a'..'f'
    return val - 87
  elseif val >= 65 and val <= 70 then    -- 'A'..'F'
    return val - 55
  else
    terror(TProtocolException:new{message = "Expected hex val ([0-9a-fA-F]); got " .. ch})
  end
end

-- Reads the remainder of a "\u00XX" escape (after the 'u'): two '0'
-- characters followed by two hex digits, and returns the byte value they
-- encode as a number. The `ch` parameter is not used.
function TJSONProtocol:readJSONEscapeChar(ch)
  self:readJSONSyntaxChar(JSONNode.ZeroChar)
  self:readJSONSyntaxChar(JSONNode.ZeroChar)
  local b1 = self.trans:readAll(1)
  local b2 = self.trans:readAll(1)
  return libluabitwise.shiftl(self:hexVal(b1), 4) + self:hexVal(b2)
end


-- Reads a quoted JSON string (after any separator the current context
-- requires) and returns its unescaped contents.
-- Supported escapes are the single-character ones listed in
-- JSONNode.EscapeChars and the "\u00XX" form; anything else after a backslash
-- is reported through terror with a TProtocolException.
function TJSONProtocol:readJSONString()
  self:readElemSeparator()
  self:readJSONSyntaxChar(JSONNode.StringDelimiter)
  -- Collect pieces and join once; repeated `..` would be quadratic.
  local parts = {}
  while true do
    local ch = self.trans:readAll(1)
    if ch == JSONNode.StringDelimiter then
      break
    end
    if ch == JSONNode.Backslash then
      ch = self.trans:readAll(1)
      if ch == JSONNode.EscapeChar then
        -- readJSONEscapeChar returns the byte value; turn it back into a
        -- character instead of appending the literal 'u'.
        ch = string.char(self:readJSONEscapeChar(ch))
      else
        -- Plain search: `ch` comes from the input and must not be
        -- interpreted as a Lua pattern ('.' would match anything, '%' would
        -- raise a malformed-pattern error).
        local pos = string.find(JSONNode.EscapeChars, ch, 1, true)
        if pos == nil then
          terror(TProtocolException:new{message = "Expected control char, got " .. ch})
        end
        ch = EscapeCharVals[pos]
      end
    end
    parts[#parts + 1] = ch
  end
  return table.concat(parts)
end

-- Reads a quoted JSON string and returns the bytes obtained by
-- base64-decoding it.
function TJSONProtocol:readJSONBase64()
  local encoded = self:readJSONString()
  local remaining = string.len(encoded)
  local parts = {}
  local offset = 1
  while remaining >= 4 do
    -- Decode exactly one 4-character quantum. string.sub is inclusive on
    -- both ends, so the slice is [offset, offset + 3].
    parts[#parts + 1] = base64_decode(string.sub(encoded, offset, offset + 3))
    offset = offset + 4
    remaining = remaining - 4
  end
  if remaining > 0 then
    -- Unpadded tail of 1-3 characters.
    parts[#parts + 1] = base64_decode(string.sub(encoded, offset))
  end
  return table.concat(parts)
end

-- Reads characters from the transport for as long as they can belong to a
-- number ('-', '+', digits, '.', 'E', 'e') and returns them concatenated.
-- The first character that does not qualify is stored in hasReadByte so the
-- next syntax read sees it.
function TJSONProtocol:readJSONNumericChars()
  local digits = ""
  local ch = self.trans:readAll(1)
  while string.find(ch, '[-+0-9.Ee]') do
    digits = digits .. ch
    ch = self.trans:readAll(1)
  end
  self.hasReadByte = ch
  return digits
end

-- Reads an integer (after any required separator) and returns its digits as
-- a string, without converting to a number. Surrounding quotes are consumed
-- when escapeNum() says the number is quoted in the current context.
function TJSONProtocol:readJSONLongInteger()
  self:readElemSeparator()
  local quoted = self:escapeNum()
  if quoted then
    self:readJSONSyntaxChar(JSONNode.StringDelimiter)
  end
  local digits = self:readJSONNumericChars()
  if quoted then
    self:readJSONSyntaxChar(JSONNode.StringDelimiter)
  end
  return digits
end

-- Reads an integer and returns it converted with tonumber() (nil when the
-- characters read do not form a number).
function TJSONProtocol:readJSONInteger()
  return tonumber(self:readJSONLongInteger())
end

-- Reads a double (after any required separator) and returns it as a number.
-- A quoted value is either one of the non-finite spellings (NaN, Infinity,
-- -Infinity) or an ordinary number that was quoted by the writer; an unquoted
-- value is parsed from its numeric characters. Returns nil when the text is
-- not a number.
function TJSONProtocol:readJSONDouble()
  self:readElemSeparator()
  -- Use a pending lookahead byte if there is one, as readJSONSyntaxChar
  -- does, so it is not left behind for a later read.
  local delimiter = self.hasReadByte
  if delimiter ~= "" then
    self.hasReadByte = ""
  else
    delimiter = self.trans:readAll(1)
  end
  local num
  if delimiter == JSONNode.StringDelimiter then
    -- The separator and the opening quote are already consumed, so read the
    -- quoted body directly; readJSONString would consume both again.
    local chars = {}
    while true do
      local ch = self.trans:readAll(1)
      if ch == JSONNode.StringDelimiter then
        break
      end
      chars[#chars + 1] = ch
    end
    local str = table.concat(chars)
    if str == JSONNode.Nan then
      num = 0/0
    elseif str == JSONNode.Infinity then
      num = math.huge
    elseif str == JSONNode.NegativeInfinity then
      num = -math.huge
    else
      num = tonumber(str)
    end
  else
    if self:escapeNum() then
      -- A number in this position must be quoted, but the quote is missing.
      terror(TProtocolException:new{
        message = "Expected " .. JSONNode.StringDelimiter .. ", got " .. delimiter
      })
    end
    local result = self:readJSONNumericChars()
    num = tonumber(delimiter .. result)
  end
  return num
end

-- Consumes any required separator and '{', then enters a new object context
-- (ttype = 1).
function TJSONProtocol:readJSONObjectBegin()
  self:readElemSeparator()
  self:readJSONSyntaxChar(JSONNode.ObjectBegin)
  self:contextPush({first = true, colon = true, ttype = 1, null = false})
end

-- Consumes '}' and leaves the current context.
function TJSONProtocol:readJSONObjectEnd()
  self:readJSONSyntaxChar(JSONNode.ObjectEnd)
  self:contextPop()
end

-- Consumes any required separator and '[', then enters a new array context
-- (ttype = 2).
function TJSONProtocol:readJSONArrayBegin()
  self:readElemSeparator()
  self:readJSONSyntaxChar(JSONNode.ArrayBegin)
  self:contextPush({first = true, colon = true, ttype = 2, null = false})
end

-- Consumes ']' and leaves the current context.
function TJSONProtocol:readJSONArrayEnd()
  self:readJSONSyntaxChar(JSONNode.ArrayEnd)
  self:contextPop()
end

-- Reads a message header: resets the context, opens the message array,
-- checks the protocol version and reads the name, message type and sequence
-- id. A version mismatch is reported through terror with a
-- TProtocolException.
-- Returns: name, ttype, seqid
function TJSONProtocol:readMessageBegin()
  self:resetContext()
  self:readJSONArrayBegin()
  local version = self:readJSONInteger()
  if version ~= self.THRIFT_JSON_PROTOCOL_VERSION then
    terror(TProtocolException:new{message = "Message contained bad version."})
  end
  local name = self:readJSONString()
  local ttype = self:readJSONInteger()
  local seqid = self:readJSONInteger()
  return name, ttype, seqid
end

-- Consumes the end of the array opened by readMessageBegin.
function TJSONProtocol:readMessageEnd()
  self:readJSONArrayEnd()
end

-- Consumes the start of a struct (a JSON object). Always returns nil, as no
-- struct name is present in the stream.
function TJSONProtocol:readStructBegin()
  self:readJSONObjectBegin()
  return nil
end

-- Consumes the end of the JSON object opened by readStructBegin.
function TJSONProtocol:readStructEnd()
  self:readJSONObjectEnd()
end

-- Reads the start of the next field of a struct.
-- One byte is peeked (and left in hasReadByte for the following read). If it
-- is '}', the struct has no more fields and (nil, TType.STOP, 0) is returned.
-- Otherwise the field id, the opening of the field object and the type tag
-- are read.
-- Returns: nil (no field name in the stream), ttype, id
function TJSONProtocol:readFieldBegin()
  local peeked = self.trans:readAll(1)
  self.hasReadByte = peeked
  if peeked == JSONNode.ObjectEnd then
    return nil, TType.STOP, 0
  end
  local id = self:readJSONInteger()
  self:readJSONObjectBegin()
  local ttype = StringToTType[self:readJSONString()]
  return nil, ttype, id
end

-- Consumes the end of the object opened by readFieldBegin.
function TJSONProtocol:readFieldEnd()
  self:readJSONObjectEnd()
end

-- Reads a map header: opens the array, reads the key and value type tags and
-- the size, then opens the object that holds the entries.
-- Returns: ktype, vtype, size (a type is nil when its tag is unknown)
function TJSONProtocol:readMapBegin()
  self:readJSONArrayBegin()
  local typeName = self:readJSONString()
  local ktype = StringToTType[typeName]
  typeName = self:readJSONString()
  local vtype = StringToTType[typeName]
  local size = self:readJSONInteger()
  self:readJSONObjectBegin()
  return ktype, vtype, size
end

-- Consumes the end of the entries object and then of the enclosing array
-- opened by readMapBegin.
function TJSONProtocol:readMapEnd()
  self:readJSONObjectEnd()
  self:readJSONArrayEnd()
end

-- Reads a list header: opens the array and reads the element type tag and
-- the size.
-- Returns: etype, size (etype is nil when the tag is unknown)
function TJSONProtocol:readListBegin()
  self:readJSONArrayBegin()
  local typeName = self:readJSONString()
  local etype = StringToTType[typeName]
  local size = self:readJSONInteger()
  return etype, size
end

-- Consumes the end of the array opened by readListBegin.
function TJSONProtocol:readListEnd()
  return self:readJSONArrayEnd()
end

-- Reads a set header; sets share the list encoding, so this delegates to
-- readListBegin and returns its results (etype, size).
function TJSONProtocol:readSetBegin()
  return self:readListBegin()
end

-- Consumes the end of the array opened by readSetBegin.
function TJSONProtocol:readSetEnd()
  return self:readJSONArrayEnd()
end

-- Reads a boolean, encoded as an integer: true only when the value is 1.
function TJSONProtocol:readBool()
  return self:readJSONInteger() == 1
end

-- Reads a byte value, encoded as a JSON integer. Values of 256 and above are
-- reported through terror with a TProtocolException.
-- NOTE(review): only the upper bound is checked, and against 256 rather than
-- a signed 8-bit limit - confirm the intended range.
function TJSONProtocol:readByte()
  local result = self:readJSONInteger()
  if result >= 256 then
    terror(TProtocolException:new{message = "UnExpected Byte " .. result})
  end
  return result
end

-- Reads a 16-bit integer; no range check is applied to the parsed number.
function TJSONProtocol:readI16()
  return self:readJSONInteger()
end

-- Reads a 32-bit integer; no range check is applied to the parsed number.
function TJSONProtocol:readI32()
  return self:readJSONInteger()
end

-- Reads a 64-bit integer. The digits are kept as a string and passed to
-- liblualongnumber.new so the value is not narrowed by tonumber().
-- NOTE(review): liblualongnumber is not among the modules required at the top
-- of this file; it is assumed to be loaded elsewhere - verify.
function TJSONProtocol:readI64()
  local long = liblualongnumber.new
  return long(self:readJSONLongInteger())
end

-- Reads a double; see readJSONDouble.
function TJSONProtocol:readDouble()
  return self:readJSONDouble()
end

-- Reads a quoted JSON string and returns its unescaped contents.
function TJSONProtocol:readString()
  return self:readJSONString()
end

-- Reads a quoted base64 string and returns the decoded bytes.
function TJSONProtocol:readBinary()
  return self:readJSONBase64()
end

-- Factory object that creates TJSONProtocol instances (see getProtocol).
TJSONProtocolFactory = TProtocolFactory:new{
  __type = 'TJSONProtocolFactory',
}

-- Creates a TJSONProtocol bound to the transport `trans`.
-- A nil or false `trans` is reported through terror with a
-- TProtocolException; no other validation of `trans` is done.
function TJSONProtocolFactory:getProtocol(trans)
  -- TODO Enforce that this must be a transport class (ie not a bool)
  if not trans then
    terror(TProtocolException:new{
      message = 'Must supply a transport to ' .. ttype(self)
    })
  end
  return TJSONProtocol:new{
    trans = trans
  }
end