Compare commits

..

24 commits

Author SHA1 Message Date
Renan I. Del Valle
efc0fdcd81
Moving future to final 0.22.0 release and Mesos 1.6.2 (#114)
Changes in compose testing setup:
* Upgrading Aurora to 0.22.0
* Upgrading Mesos to 1.6.2
2020-01-14 15:34:59 -08:00
Renan I. Del Valle
b505304b79
Adding autopause APIs to future (#110)
* Updating thrift definitions to add autopause for batch based update strategies.

* Adding batch calculator utility and test cases for it.

* Adding PauseUpdateMonitor which allows users to poll Aurora for information on an active Update being carried out until it enters the ROLL_FORWARD_PAUSED state.

* Tests for PauseUpdateMonitor and VariableBatchStep added to the end to end tests.

* Adding TerminalUpdateStates function which returns a slice containing all terminal states for an update. Changed signature of JobUpdateStatus from using a map for desired states to a slice. A map is no longer necessary with the new version of thrift and only adds complexity.
2020-01-13 16:03:40 -08:00
Renan DelValle
2148351b94
Variable Batch Update Support (#100)
* Changing generateBinding.sh check to check for thrift 0.12.0 and adding support for Variable Batch updates.

* Adding update strategies change to changelog, changed docker-compose to point to aurora 0.22.0 snapshot. Added test coverage for update strategies.
2019-08-28 17:12:59 -07:00
Renan DelValle
f936576c4d
Increasing aurora version for future branch. 2019-08-28 17:11:32 -07:00
Renan DelValle
df8fc2fba1
Documentation and linting improvements (#108)
* Simplifying documentation for getting started: Removed outdated information about install Golang on different platforms and instead included a link to the official Golang website which has more up to date information. Instructions for installing docker-compose have also been added.

* Added documentation to all exported functions and structs.

* Unexported some structures and functions that were needlessly exported.

* Adding golang CI default configuration which can be useful while developing and may be turned on later in the CI.

* Moving build process in CI to xenial.

* Reducing line size. in some files and shadowing in some test cases.
2019-06-12 11:22:59 -07:00
Renan DelValle
6dc4bf93b9
Retry temporary errors by default (#107)
* Adding Aurora URL validator in order to handle scenarios where incomplete information is passed to the client. The client will do its best to guess the missing information such as protocol and port.

* Upgraded to testify 1.3.0.

* Added configuration to fail on a non-temporary error. This is reverting to the original behavior of the retry mechanism. However, this allows the user to opt to fail in a non-temporary error.
2019-06-11 11:47:14 -07:00
Renan DelValle
4ffb509939
Adding go mod files to v1 (#106)
* Declaring dependencies using go mod.
2019-05-06 11:33:14 -07:00
Renan DelValle
1a15c4a5aa
V1 CreateService and StartJobUpdate Timeout signal and cleanup (#105)
* Bumped up version to 1.21.1

* Moving admin functions to a new file. They are still part of the same pointer receiver type.

* Removing dead code and fixing some comments to add space between backslash and comment.

* Adding set up and tear down to run tests script. It sets up a pod, runs all tests, and then tears down the pod.

* Added `--rm` to run tests Mac script.

* Removing cookie jar from transport layer as it's not needed.

* Changing all error messages to start with a lower case letter. Changing some messages around to be more descriptive.

* Adding an argument to allow the retry mechanism to stop if a timeout has been encountered. This is useful for mutating API calls. Only StartUpdate and CreateService have enabled by default stop at timeout.

* Added 2 tests for when a call goes through despite the client timing out. One is with a good payload, one is with a bad payload.

* Updating changelog with information about the error type returned.

* Adding test for duplicate metadata.

* Refactored JobUpdateStatus monitor to use a new monitor called JobUpdateQuery. Update monitor will now still continue if it does not find an update to monitor. Furthermore, it has been optimized to reduce returning payloads from the scheduler as much as possible. This is through using the GetJobUpdateSummaries API instead of JobUpdateDetails and by including a the statuses we're searching for as part of the query.


* Added documentation as to how to handle a timeout on an API request.

* Optimized GetInstancesIds to create a copy of the JobKey being passed down in order to avoid unexpected behavior. Instead of setting every variable name separately, now a JobKey array is being created.
2019-05-05 11:46:22 -07:00
Renan DelValle
e16e390afe
1.21.0 (formerly 1.4.0) release 2019-03-15 15:15:37 -07:00
Renan DelValle
f7bd7cc20f
Bug fix for metadata duplicates as well as un-initialized GPU re… (#103)
* Fix for metadata duplicates as well.
* Fix for un-initialized GPU resource when creating a new job update.
2019-03-15 15:10:31 -07:00
Renan DelValle
c997b90720
Adding future branch to testing. 2019-03-15 12:17:43 -07:00
Renan DelValle
773d842b03
Adding missing GPU to Job interface. 2019-03-05 11:43:50 -08:00
Renan DelValle
1f459dd56a
Adds support for Tier and SlaPolicy to the Job interface (#99)
* Adding parameter for Aurora so that we're able to run SLA aware updates with less than 20 instances. Lowered time it takes to run test by reducing watch time per instance as well.

* Reducing the number of instances and time for SLA aware instances in docker-compose set up.

* Adding another Mesos agent to the docker-compose setup.

* Huge thanks to @zircote for this contribution.
2019-02-20 16:36:50 -08:00
Renan DelValle
79fa7ba16d
Upgrading gorealis v1 to Thrift 0.12.0 code generation. End to end tests cleanup (#96)
* Ported all code from Thrift 0.9.3 to Thrift 0.12.0 while backporting some fixes from gorealis v2

* Removing git.apache.org dependency from Vendor folder as this dependency has migrated to github.

* Adding github.com thrift dependency back but now it points to github.com

* Removing unnecessary files from Thrift Vendor folder and adding them to .gitignore.

* Updating dep dependencies to include Thrift 0.12.0 from github.com

* Adding changelog.

* End to end tests: Adding coverage for killinstances.

*  End to end tests: Deleting instances after partition policy recovers them.

*  End to end tests: Adding more coverage to the realis API.

*  End to end tests: Allowing arguments to be passed to runTestMac so that '-run <test name>' can be passed in.

*  End to end tests: Reducing the resources used by CreateJob test.

*  End to end tests: Adding coverage for Pause and Resume update.

*   End to end tests: Removed checks for Aurora_OK response as that should always be handled by the error returned by the API. Changed names to be less verbose and repetitive.

*  End to end tests: Reducing watch time for instance running when creating service for reducing time it takes to run end to end test.
2019-02-20 11:11:46 -08:00
Renan DelValle
2b7eb3a852
Making abort job synchronous (#95)
* Making abort job synchronous to avoid scenarios where kill is received before job update lock is released.
* Adding missing cases for terminal update statues to JobUpdate monitor.
* Monitors now return errors which provide context through behavior.
* Adding notes to the doc explaining what happens when AbortJob times out.
2019-01-15 14:55:59 -08:00
Renan DelValle
10c620de7b
Fixing logger not unrolling variadic argument when appending to the front of it. 2019-01-11 12:20:01 -08:00
Renan DelValle
1d3854aa5f
Trace level for logger (#94)
* Add trace level to print out response thrift objects. Allows user to control whether these are printed or not to avoid pollution.

* Using named parameters to be more explicit about what is being set for LevelLogger.

* Adding TracePrint and TracePrintln. Inlined library level prefixes.
2019-01-10 16:58:59 -08:00
Renan DelValle
73e7ab2671
Releasing version 1.3.1 2019-01-08 15:57:19 -08:00
Renan DelValle
22b1d82d88
Bug fix for logger interface. Varidic arguments need to be unrolled when passed to print functions. 2019-01-08 15:37:25 -08:00
Renan DelValle
2f7015571c
Adding support for setting GPU as a resource. (#93)
* Adding support for setting GPU as a resource.
* Refactoring pulse update test.
2019-01-08 15:11:52 -08:00
Robert Allen
296af622d1 This adds the following function to the PartitionPolicy configuration to the Job interface (#91)
* Adding Partition Policy API
2018-12-20 14:38:06 -08:00
Renan DelValle
9a835631b2
Running goimports on all repository to conform to newest goimports. 2018-12-19 15:33:35 -08:00
Renan DelValle
b100158080
Updating Travis CI config file to include running CI on master-v2.0 branch 2018-12-19 15:30:22 -08:00
Renan DelValle
45a4416830
Adding .gitattributes to ignore generated files. 2018-12-03 16:09:46 -08:00
228 changed files with 48960 additions and 15040 deletions

View file

@ -1,5 +0,0 @@
[users]
aurora = secret, admin
[roles]
admin = *

View file

@ -1 +1 @@
0.26.0
0.22.0

View file

@ -1,32 +0,0 @@
name: CI
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
build:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Setup Go for use with actions
uses: actions/setup-go@v2
with:
go-version: 1.17
- name: Install goimports
run: go get golang.org/x/tools/cmd/goimports
- name: Set env with list of directories in repo containin go code
run: echo GO_USR_DIRS=$(go list -f {{.Dir}} ./... | grep -E -v "/gen-go/|/vendor/") >> $GITHUB_ENV
- name: Run goimports check
run: test -z "`for d in $GO_USR_DIRS; do goimports -d $d/*.go | tee /dev/stderr; done`"
- name: Create aurora/mesos docker cluster
run: docker-compose up -d
- name: Run tests
run: go test -timeout 35m -race -coverprofile=coverage.txt -covermode=atomic -v github.com/aurora-scheduler/gorealis/v2

4
.gitignore vendored
View file

@ -37,7 +37,3 @@ _testmain.go
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Example client build
examples/client
examples/jsonClient

71
.golangci.yml Normal file
View file

@ -0,0 +1,71 @@
# This file contains all available configuration options
# with their default values.
# options for analysis running
run:
# default concurrency is a available CPU number
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 1m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1
# include test files or not, default is true
tests: true
skip-dirs:
- gen-go/
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
format: colored-line-number
# print lines of code with issue, default is true
print-issued-lines: true
# print linter name in the end of issue text, default is true
print-linter-name: true
# all available settings of specific linters
linters-settings:
errcheck:
# report about not checking of errors in type assetions: `a := b.(MyStruct)`;
# default is false: such cases aren't reported by default.
check-type-assertions: true
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`;
# default is false: such cases aren't reported by default.
check-blank: true
govet:
# report about shadowed variables
check-shadowing: true
goconst:
# minimal length of string constant, 3 by default
min-len: 3
# minimal occurrences count to trigger, 3 by default
min-occurrences: 2
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
locale: US
lll:
# max line length, lines longer will be reported. Default is 120.
# '\t' is counted as 1 character by default, and can be changed with the tab-width option
line-length: 120
# tab width in spaces. Default to 1.
tab-width: 4
linters:
enable:
- govet
- goimports
- golint
- lll
- goconst
enable-all: false
fast: false

View file

@ -1,9 +1,16 @@
sudo: required
dist: xenial
language: go
branches:
only:
- master
- master-v2.0
- future
go:
- "1.11.x"
- "1.10.x"
env:
global:
@ -20,7 +27,7 @@ install:
- docker-compose up -d
script:
- go test -race -coverprofile=coverage.txt -covermode=atomic -v github.com/aurora-scheduler/gorealis
- go test -race -coverprofile=coverage.txt -covermode=atomic -v github.com/paypal/gorealis
after_success:
- bash <(curl -s https://codecov.io/bash)

25
CHANGELOG.md Normal file
View file

@ -0,0 +1,25 @@
1.22.0 (unreleased)
* CreateService and StartJobUpdate do not continue retrying if a timeout has been encountered
by the HTTP client. Instead they now return an error that conforms to the Timedout interface.
Users can check for a Timedout error by using `realis.IsTimeout(err)`.
* New API function VariableBatchStep has been added which returns the current batch at which
a Variable Batch Update configured Update is currently in.
* Added new PauseUpdateMonitor which monitors an update until it is an `ROLL_FORWARD_PAUSED` state.
* Added variableBatchStep command to sample client to be used for testing new VariableBatchStep api.
* JobUpdateStatus has changed function signature from:
`JobUpdateStatus(updateKey aurora.JobUpdateKey, desiredStatuses map[aurora.JobUpdateStatus]bool, interval, timeout time.Duration) (aurora.JobUpdateStatus, error)`
to
`JobUpdateStatus(updateKey aurora.JobUpdateKey, desiredStatuses []aurora.JobUpdateStatus, interval, timeout time.Duration) (aurora.JobUpdateStatus, error)`
* Added TerminalUpdateStates function which returns an slice containing all UpdateStates which are considered terminal states.
1.21.0
* Version numbering change. Future versions will be labled X.Y.Z where X is the major version, Y is the Aurora version the library has been tested against (e.g. 21 -> 0.21.0), and X is the minor revision.
* Moved to Thrift 0.12.0 code generator and go library.
* `aurora.ACTIVE_STATES`, `aurora.SLAVE_ASSIGNED_STATES`, `aurora.LIVE_STATES`, `aurora.TERMINAL_STATES`, `aurora.ACTIVE_JOB_UPDATE_STATES`, `aurora.AWAITNG_PULSE_JOB_UPDATE_STATES` are all now generated as a slices.
* Please use `realis.ActiveStates`, `realis.SlaveAssignedStates`,`realis.LiveStates`, `realis.TerminalStates`, `realis.ActiveJobUpdateStates`, `realis.AwaitingPulseJobUpdateStates` in their places when map representations are needed.
* `GetInstanceIds(key *aurora.JobKey, states map[aurora.ScheduleStatus]bool) (map[int32]bool, error)` has changed signature to ` GetInstanceIds(key *aurora.JobKey, states []aurora.ScheduleStatus) ([]int32, error)`
* Adding support for GPU as resource.
* Changing compose environment to Aurora snapshot in order to support staggered update.
* Adding staggered updates API.

64
Gopkg.lock generated Normal file
View file

@ -0,0 +1,64 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
digest = "1:89696c38cec777120b8b1bb5e2d363d655cf2e1e7d8c851919aaa0fd576d9b86"
name = "github.com/apache/thrift"
packages = ["lib/go/thrift"]
pruneopts = ""
revision = "384647d290e2e4a55a14b1b7ef1b7e66293a2c33"
version = "v0.12.0"
[[projects]]
digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b"
name = "github.com/davecgh/go-spew"
packages = ["spew"]
pruneopts = ""
revision = "346938d642f2ec3594ed81d874461961cd0faa76"
version = "v1.1.0"
[[projects]]
digest = "1:df48fb76fb2a40edea0c9b3d960bc95e326660d82ff1114e1f88001f7a236b40"
name = "github.com/pkg/errors"
packages = ["."]
pruneopts = ""
revision = "e881fd58d78e04cf6d0de1217f8707c8cc2249bc"
[[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib"
packages = ["difflib"]
pruneopts = ""
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
digest = "1:78bea5e26e82826dacc5fd64a1013a6711b7075ec8072819b89e6ad76cb8196d"
name = "github.com/samuel/go-zookeeper"
packages = ["zk"]
pruneopts = ""
revision = "471cd4e61d7a78ece1791fa5faa0345dc8c7d5a5"
[[projects]]
digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c"
name = "github.com/stretchr/testify"
packages = [
"assert",
"require",
]
pruneopts = ""
revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053"
version = "v1.3.0"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
input-imports = [
"github.com/apache/thrift/lib/go/thrift",
"github.com/pkg/errors",
"github.com/samuel/go-zookeeper/zk",
"github.com/stretchr/testify/assert",
"github.com/stretchr/testify/require",
]
solver-name = "gps-cdcl"
solver-version = 1

16
Gopkg.toml Normal file
View file

@ -0,0 +1,16 @@
[[constraint]]
name = "github.com/apache/thrift"
version = "0.12.0"
[[constraint]]
name = "github.com/pkg/errors"
revision = "e881fd58d78e04cf6d0de1217f8707c8cc2249bc"
[[constraint]]
name = "github.com/samuel/go-zookeeper"
revision = "471cd4e61d7a78ece1791fa5faa0345dc8c7d5a5"
[[constraint]]
name = "github.com/stretchr/testify"
version = "1.3.0"

View file

@ -1,6 +1,6 @@
# gorealis [![GoDoc](https://godoc.org/github.com/aurora-scheduler/gorealis?status.svg)](https://godoc.org/github.com/aurora-scheduler/gorealis) [![codecov](https://codecov.io/gh/aurora-scheduler/gorealis/branch/master/graph/badge.svg)](https://codecov.io/gh/aurora-scheduler/gorealis/branch/master)
# gorealis [![GoDoc](https://godoc.org/github.com/paypal/gorealis?status.svg)](https://godoc.org/github.com/paypal/gorealis) [![Build Status](https://travis-ci.org/paypal/gorealis.svg?branch=master)](https://travis-ci.org/paypal/gorealis) [![codecov](https://codecov.io/gh/paypal/gorealis/branch/master/graph/badge.svg)](https://codecov.io/gh/paypal/gorealis)
Go library for interacting with [Aurora Scheduler](https://github.com/aurora-scheduler/aurora).
Go library for interacting with [Apache Aurora](https://github.com/apache/aurora).
### Aurora version compatibility
Please see [.auroraversion](./.auroraversion) to see the latest Aurora version against which this
@ -14,7 +14,7 @@ library has been tested.
## Projects using gorealis
* [australis](https://github.com/aurora-scheduler/australis)
* [australis](https://github.com/rdelval/australis)
## Contributions
Contributions are always welcome. Please raise an issue to discuss a contribution before it is made.

View file

@ -21,6 +21,8 @@ import (
"github.com/pkg/errors"
)
// Cluster contains the definition of the clusters.json file used by the default Aurora
// client for configuration
type Cluster struct {
Name string `json:"name"`
AgentRoot string `json:"slave_root"`
@ -28,13 +30,13 @@ type Cluster struct {
ZK string `json:"zk"`
ZKPort int `json:"zk_port"`
SchedZKPath string `json:"scheduler_zk_path"`
MesosZKPath string `json:"mesos_zk_path"`
SchedURI string `json:"scheduler_uri"`
ProxyURL string `json:"proxy_url"`
AuthMechanism string `json:"auth_mechanism"`
}
// Loads clusters.json file traditionally located at /etc/aurora/clusters.json
// LoadClusters loads clusters.json file traditionally located at /etc/aurora/clusters.json
// for use with a gorealis client
func LoadClusters(config string) (map[string]Cluster, error) {
file, err := os.Open(config)
@ -55,15 +57,3 @@ func LoadClusters(config string) (map[string]Cluster, error) {
return m, nil
}
func GetDefaultClusterFromZKUrl(zkURL string) *Cluster {
return &Cluster{
Name: "defaultCluster",
AuthMechanism: "UNAUTHENTICATED",
ZK: zkURL,
SchedZKPath: "/aurora/scheduler",
MesosZKPath: "/mesos",
AgentRunDir: "latest",
AgentRoot: "/var/lib/mesos",
}
}

View file

@ -18,7 +18,7 @@ import (
"fmt"
"testing"
realis "github.com/aurora-scheduler/gorealis/v2"
realis "github.com/paypal/gorealis"
"github.com/stretchr/testify/assert"
)
@ -32,7 +32,6 @@ func TestLoadClusters(t *testing.T) {
assert.Equal(t, clusters["devcluster"].Name, "devcluster")
assert.Equal(t, clusters["devcluster"].ZK, "192.168.33.7")
assert.Equal(t, clusters["devcluster"].SchedZKPath, "/aurora/scheduler")
assert.Equal(t, clusters["devcluster"].MesosZKPath, "/mesos")
assert.Equal(t, clusters["devcluster"].AuthMechanism, "UNAUTHENTICATED")
assert.Equal(t, clusters["devcluster"].AgentRunDir, "latest")
assert.Equal(t, clusters["devcluster"].AgentRoot, "/var/lib/mesos")

View file

@ -15,31 +15,44 @@
package realis
import (
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/paypal/gorealis/gen-go/apache/aurora"
)
// Container is an interface that defines a single function needed to create
// an Aurora container type. It exists because the code must support both Mesos
// and Docker containers.
type Container interface {
Build() *aurora.Container
}
// MesosContainer is a Mesos style container that can be used by Aurora Jobs.
type MesosContainer struct {
container *aurora.MesosContainer
}
// DockerContainer is a vanilla Docker style container that can be used by Aurora Jobs.
type DockerContainer struct {
container *aurora.DockerContainer
}
func NewDockerContainer() *DockerContainer {
return &DockerContainer{container: aurora.NewDockerContainer()}
// NewDockerContainer creates a new Aurora compatible Docker container configuration.
func NewDockerContainer() DockerContainer {
return DockerContainer{container: aurora.NewDockerContainer()}
}
func (c *DockerContainer) Build() *aurora.Container {
// Build creates an Aurora container based upon the configuration provided.
func (c DockerContainer) Build() *aurora.Container {
return &aurora.Container{Docker: c.container}
}
func (c *DockerContainer) Image(image string) *DockerContainer {
// Image adds the name of a Docker image to be used by the Job when running.
func (c DockerContainer) Image(image string) DockerContainer {
c.container.Image = image
return c
}
func (c *DockerContainer) AddParameter(name, value string) *DockerContainer {
// AddParameter adds a parameter to be passed to Docker when the container is run.
func (c DockerContainer) AddParameter(name, value string) DockerContainer {
c.container.Parameters = append(c.container.Parameters, &aurora.DockerParameter{
Name: name,
Value: value,
@ -47,19 +60,18 @@ func (c *DockerContainer) AddParameter(name, value string) *DockerContainer {
return c
}
type MesosContainer struct {
container *aurora.MesosContainer
// NewMesosContainer creates a Mesos style container to be configured and built for use by an Aurora Job.
func NewMesosContainer() MesosContainer {
return MesosContainer{container: aurora.NewMesosContainer()}
}
func NewMesosContainer() *MesosContainer {
return &MesosContainer{container: aurora.NewMesosContainer()}
}
func (c *MesosContainer) Build() *aurora.Container {
// Build creates a Mesos style Aurora container configuration to be passed on to the Aurora Job.
func (c MesosContainer) Build() *aurora.Container {
return &aurora.Container{Mesos: c.container}
}
func (c *MesosContainer) DockerImage(name, tag string) *MesosContainer {
// DockerImage configures the Mesos container to use a specific Docker image when being run.
func (c MesosContainer) DockerImage(name, tag string) MesosContainer {
if c.container.Image == nil {
c.container.Image = aurora.NewImage()
}
@ -68,20 +80,12 @@ func (c *MesosContainer) DockerImage(name, tag string) *MesosContainer {
return c
}
func (c *MesosContainer) AppcImage(name, imageId string) *MesosContainer {
// AppcImage configures the Mesos container to use an image in the Appc format to run the container.
func (c MesosContainer) AppcImage(name, imageID string) MesosContainer {
if c.container.Image == nil {
c.container.Image = aurora.NewImage()
}
c.container.Image.Appc = &aurora.AppcImage{Name: name, ImageId: imageId}
return c
}
func (c *MesosContainer) AddVolume(hostPath, containerPath string, mode aurora.Mode) *MesosContainer {
c.container.Volumes = append(c.container.Volumes, &aurora.Volume{
HostPath: hostPath,
ContainerPath: containerPath,
Mode: mode})
c.container.Image.Appc = &aurora.AppcImage{Name: name, ImageId: imageID}
return c
}

View file

@ -14,7 +14,7 @@ services:
ipv4_address: 192.168.33.2
master:
image: quay.io/aurorascheduler/mesos-master:1.9.0
image: rdelvalle/mesos-master:1.6.2
restart: on-failure
ports:
- "5050:5050"
@ -32,7 +32,7 @@ services:
- zk
agent-one:
image: quay.io/aurorascheduler/mesos-agent:1.9.0
image: rdelvalle/mesos-agent:1.6.2
pid: host
restart: on-failure
ports:
@ -41,11 +41,10 @@ services:
MESOS_MASTER: zk://192.168.33.2:2181/mesos
MESOS_CONTAINERIZERS: docker,mesos
MESOS_PORT: 5051
MESOS_HOSTNAME: agent-one
MESOS_HOSTNAME: localhost
MESOS_RESOURCES: ports(*):[11000-11999]
MESOS_SYSTEMD_ENABLE_SUPPORT: 'false'
MESOS_WORK_DIR: /tmp/mesos
MESOS_ATTRIBUTES: 'host:agent-one;rack:1;zone:west'
networks:
aurora_cluster:
ipv4_address: 192.168.33.4
@ -57,57 +56,31 @@ services:
- zk
agent-two:
image: quay.io/aurorascheduler/mesos-agent:1.9.0
image: rdelvalle/mesos-agent:1.6.2
pid: host
restart: on-failure
ports:
- "5052:5051"
- "5061:5061"
environment:
MESOS_MASTER: zk://192.168.33.2:2181/mesos
MESOS_CONTAINERIZERS: docker,mesos
MESOS_PORT: 5051
MESOS_HOSTNAME: agent-two
MESOS_HOSTNAME: localhost
MESOS_PORT: 5061
MESOS_RESOURCES: ports(*):[11000-11999]
MESOS_SYSTEMD_ENABLE_SUPPORT: 'false'
MESOS_WORK_DIR: /tmp/mesos
MESOS_ATTRIBUTES: 'host:agent-two;rack:2;zone:west'
networks:
aurora_cluster:
ipv4_address: 192.168.33.5
volumes:
- /sys/fs/cgroup:/sys/fs/cgroup
- /var/run/docker.sock:/var/run/docker.sock
- /sys/fs/cgroup:/sys/fs/cgroup
- /var/run/docker.sock:/var/run/docker.sock
depends_on:
- zk
agent-three:
image: quay.io/aurorascheduler/mesos-agent:1.9.0
pid: host
restart: on-failure
ports:
- "5053:5051"
environment:
MESOS_MASTER: zk://192.168.33.2:2181/mesos
MESOS_CONTAINERIZERS: docker,mesos
MESOS_PORT: 5051
MESOS_HOSTNAME: agent-three
MESOS_RESOURCES: ports(*):[11000-11999]
MESOS_SYSTEMD_ENABLE_SUPPORT: 'false'
MESOS_WORK_DIR: /tmp/mesos
MESOS_ATTRIBUTES: 'host:agent-three;rack:2;zone:west;dedicated:vagrant/bar'
networks:
aurora_cluster:
ipv4_address: 192.168.33.6
volumes:
- /sys/fs/cgroup:/sys/fs/cgroup
- /var/run/docker.sock:/var/run/docker.sock
depends_on:
- zk
- zk
aurora-one:
image: quay.io/aurorascheduler/scheduler:0.25.0
image: rdelvalle/aurora:0.22.0
pid: host
ports:
- "8081:8081"
@ -116,14 +89,7 @@ services:
CLUSTER_NAME: test-cluster
ZK_ENDPOINTS: "192.168.33.2:2181"
MESOS_MASTER: "zk://192.168.33.2:2181/mesos"
EXTRA_SCHEDULER_ARGS: >
-http_authentication_mechanism=BASIC
-shiro_realm_modules=INI_AUTHNZ
-shiro_ini_path=/etc/aurora/security.ini
-min_required_instances_for_sla_check=1
-thermos_executor_cpu=0.09
volumes:
- ./.aurora-config:/etc/aurora
EXTRA_SCHEDULER_ARGS: "-min_required_instances_for_sla_check=1"
networks:
aurora_cluster:
ipv4_address: 192.168.33.7

View file

@ -19,18 +19,25 @@ This also allows us to delete and recreate our development cluster very quickly.
To install docker-compose please follow the instructions for your platform
[here](https://docs.docker.com/compose/install/).
### Getting the source code
`$ git clone https://github.com/aurora-scheduler/gorealis`
As of go 1.10.x, GOPATH is still relevant. This may change in the future but
for the sake of making development less error prone, it is suggested that the following
directories be created:
Inside of the newly cloned repo you may download dependencies to the local cache using go mod
`$ mkdir -p $GOPATH/src/github.com/paypal`
`$ go mod download`
And then clone the master branch into the newly created folder:
`$ cd $GOPATH/src/github.com/paypal; git clone git@github.com:paypal/gorealis.git`
Since we check in our vendor folder, gorealis no further set up is needed.
### Bringing up the cluster
To develop gorealis, you will need a fully functioning Mesos cluster along with
the Aurora Scheduler.
To develop gorealis, you will need a fully functioning Mesos cluster along with
Apache Aurora.
In order to bring up our docker-compose set up execute the following command from the root
of the git repository:
@ -55,14 +62,14 @@ environment but not when running under MacOS. To run code involving the ZK leade
For example, running the tests in a container can be done through the following command from
the root of the git repository:
`$ docker run -t -v $(pwd):/go/src/github.com/aurora-scheduler/gorealis --network gorealis_aurora_cluster golang:1.14.3-alpine go test github.com/paypal/gorealis`
`$ docker run -t -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10.3-alpine go test github.com/paypal/gorealis`
Or
`$ ./runTestsMac.sh`
Alternatively, if an interactive shell is necessary, the following command may be used:
`$ docker run -it -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.14.3-alpine /bin/sh`
`$ docker run -it -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10.3-alpine /bin/sh`
### Cleaning up the cluster
@ -78,3 +85,6 @@ Once development is done, the environment may be torn down by executing (from th
git directory):
`$ docker-compose down`

View file

@ -88,12 +88,6 @@ On Ubuntu, restarting the aurora-scheduler can be achieved by running the follow
$ sudo service aurora-scheduler restart
```
### Using a custom client
Pystachio does not yet support launching tasks using custom executors. Therefore, a custom
client must be used in order to launch tasks using a custom executor. In this case,
we will be using [gorealis](https://github.com/paypal/gorealis) to launch a task with
the compose executor on Aurora.
## Using [dce-go](https://github.com/paypal/dce-go)
Instead of manually configuring Aurora to run the docker-compose executor, one can follow the instructions provided [here](https://github.com/paypal/dce-go/blob/develop/docs/environment.md) to quickly create a DCE environment that would include mesos, aurora, golang1.7, docker, docker-compose and DCE installed.
@ -107,80 +101,12 @@ Mesos endpoint --> http://192.168.33.8:5050
### Installing Go
#### Linux
Follow the instructions at the official golang website: [golang.org/doc/install](https://golang.org/doc/install)
##### Ubuntu
### Installing docker-compose
###### Adding a PPA and install via apt-get
```
$ sudo add-apt-repository ppa:ubuntu-lxc/lxd-stable
$ sudo apt-get update
$ sudo apt-get install golang
```
###### Configuring the GOPATH
Configure the environment to be able to compile and run Go code.
```
$ mkdir $HOME/go
$ echo export GOPATH=$HOME/go >> $HOME/.bashrc
$ echo export GOROOT=/usr/lib/go >> $HOME/.bashrc
$ echo export PATH=$PATH:$GOPATH/bin >> $HOME/.bashrc
$ echo export PATH=$PATH:$GOROOT/bin >> $HOME/.bashrc
```
Finally we must reload the .bashrc configuration:
```
$ source $HOME/.bashrc
```
#### OS X
One way to install go on OS X is by using [Homebrew](http://brew.sh/)
##### Installing Homebrew
Run the following command from the terminal to install Hombrew:
```
$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
```
##### Installing Go using Hombrew
Run the following command from the terminal to install Go:
```
$ brew install go
```
##### Configuring the GOPATH
Configure the environment to be able to compile and run Go code.
```
$ mkdir $HOME/go
$ echo export GOPATH=$HOME/go >> $HOME/.profile
$ echo export GOROOT=/usr/local/opt/go/libexec >> $HOME/.profile
$ echo export PATH=$PATH:$GOPATH/bin >> $HOME/.profile
$ echo export PATH=$PATH:$GOROOT/bin >> $HOME/.profile
```
Finally we must reload the .profile configuration:
```
$ source $HOME/.profile
```
#### Windows
Download and run the msi installer from https://golang.org/dl/
## Installing Docker Compose (if manually configured Aurora)
To show Aurora's new multi executor feature, we need to use at least one custom executor.
In this case we will be using the [docker-compose-executor](https://github.com/mesos/docker-compose-executor).
In order to run the docker-compose executor, each agent must have docker-compose installed on it.
This can be done using pip:
```
$ sudo pip install docker-compose
```
Agents which will run dce-go will need docker-compose in order to sucessfully run the executor.
Instructions for installing docker-compose on various platforms may be found on Docker's webiste: [docs.docker.com/compose/install/](https://docs.docker.com/compose/install/)
## Downloading gorealis
Finally, we must get `gorealis` using the `go get` command:
@ -192,7 +118,7 @@ go get github.com/paypal/gorealis
# Creating Aurora Jobs
## Creating a thermos job
To demonstrate that we are able to run jobs using different executors on the
To demonstrate that we are able to run jobs using different executors on the
same scheduler, we'll first launch a thermos job using the default Aurora Client.
We can use a sample job for this:
@ -247,9 +173,6 @@ job = realis.NewJob().
RAM(64).
Disk(100).
IsService(false).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1).
AddLabel("fileName", "sample-app/docker-compose.yml").
@ -262,8 +185,8 @@ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=compo
```
If everything went according to plan, a new job will be shown in the Aurora UI.
We can further investigate inside the Mesos task sandbox. Inside the sandbox, under
the sample-app folder, we can find a docker-compose.yml-generated.yml. If we inspect this file,
We can further investigate inside the Mesos task sandbox. Inside the sandbox, under
the sample-app folder, we can find a docker-compose.yml-generated.yml. If we inspect this file,
we can find the port at which we can find the web server we launched.
Under Web->Ports, we find the port Mesos allocated. We can then navigate to:
@ -272,10 +195,10 @@ Under Web->Ports, we find the port Mesos allocated. We can then navigate to:
A message from the executor should greet us.
## Creating a Thermos job using gorealis
It is also possible to create a thermos job using gorealis. To do this, however,
It is also possible to create a thermos job using gorealis. To do this, however,
a thermos payload is required. A thermos payload consists of a JSON blob that details
the entire task as it exists inside the Aurora Scheduler. *Creating the blob is unfortunately
out of the scope of what gorealis does*, so a thermos payload must be generated beforehand or
out of the scope of what gorealis does*, so a thermos payload must be generated beforehand or
retrieved from the structdump of an existing task for testing purposes.
A sample thermos JSON payload may be found [here](../examples/thermos_payload.json) in the examples folder.
@ -294,9 +217,6 @@ job = realis.NewJob().
RAM(64).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1)
```

View file

@ -25,9 +25,6 @@ job = realis.NewJob().
RAM(64).
Disk(100).
IsService(false).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1).
AddLabel("fileName", "sample-app/docker-compose.yml").
@ -60,4 +57,19 @@ updateJob := realis.NewUpdateJob(job)
updateJob.InstanceCount(1)
updateJob.Ram(128)
msg, err := r.UpdateJob(updateJob, "")
```
```
* Handling a timeout scenario:
When sending an API call to Aurora, the call may timeout at the client side.
This means that the time limit has been reached while waiting for the scheduler
to reply. In such a case it is recommended that the timeout is increased through
the use of the `realis.TimeoutMS()` option.
As these timeouts cannot be totally avoided, there exists a mechanism to mitigate such
scenarios. The `StartJobUpdate` and `CreateService` API will return an error that
implements the Timeout interface.
An error can be checked to see if it is a Timeout error by using the `realis.IsTimeout()`
function.

View file

@ -1,6 +1,6 @@
# Using the Sample client
## Usage:
## Usage:
```
Usage of ./client:
-cluster string
@ -22,25 +22,28 @@ Usage of ./client:
```
## Sample commands:
These commands are set to run on a vagrant box. To be able to run the docker compose
executor examples, the vagrant box must be configured properly to use the docker compose executor.
### Thermos
#### Creating a Thermos job
```
$ go run examples/client.go -url=http://localhost:8081 -executor=thermos -cmd=create
$ cd $GOPATH/src/github.com/paypal/gorealis/examples
$ go run client.go -executor=thermos -url=http://192.168.33.7:8081 -cmd=create
```
#### Kill a Thermos job
```
$ go run examples/client.go -url=http://localhost:8081 -executor=thermos -cmd=kill
$ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=thermos -url=http://192.168.33.7:8081 -cmd=kill
```
### Docker Compose executor (custom executor)
#### Creating Docker Compose executor job
```
$ go run examples/client.go -url=http://192.168.33.7:8081 -executor=compose -cmd=create
$ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=compose -url=http://192.168.33.7:8081 -cmd=create
```
#### Kill a Docker Compose executor job
```
$ go run examples/client.go -url=http://192.168.33.7:8081 -executor=compose -cmd=kill
$ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=compose -url=http://192.168.33.7:8081 -cmd=kill
```

View file

@ -17,12 +17,14 @@ package realis
// Using a pattern described by Dave Cheney to differentiate errors
// https://dave.cheney.net/2016/04/27/dont-just-check-errors-handle-them-gracefully
// Timedout errors are returned when a function is unable to continue executing due
// Timeout errors are returned when a function is unable to continue executing due
// to a time constraint or meeting a set number of retries.
type timeout interface {
Timedout() bool
}
// IsTimeout returns true if the error being passed as an argument implements the Timeout interface
// and the Timedout function returns true.
func IsTimeout(err error) bool {
temp, ok := err.(timeout)
return ok && temp.Timedout()
@ -61,41 +63,42 @@ func (r *retryErr) RetryCount() int {
return r.retryCount
}
// Helper function for testing verification to avoid whitebox testing
// ToRetryCount is a helper function for testing verification to avoid whitebox testing
// as well as keeping retryErr as a private.
// Should NOT be used under any other context.
func ToRetryCount(err error) *retryErr {
if retryErr, ok := err.(*retryErr); ok {
return retryErr
} else {
return nil
}
return nil
}
func newRetryError(err error, retryCount int) *retryErr {
return &retryErr{error: err, timedout: true, retryCount: retryCount}
}
// Temporary errors indicate that the action may and should be retried.
// Temporary errors indicate that the action may or should be retried.
type temporary interface {
Temporary() bool
}
// IsTemporary indicates whether the error passed in as an argument implements the temporary interface
// and if the Temporary function returns true.
func IsTemporary(err error) bool {
temp, ok := err.(temporary)
return ok && temp.Temporary()
}
type TemporaryErr struct {
type temporaryErr struct {
error
temporary bool
}
func (t *TemporaryErr) Temporary() bool {
func (t *temporaryErr) Temporary() bool {
return t.temporary
}
// Retrying after receiving this error is advised
func NewTemporaryError(err error) *TemporaryErr {
return &TemporaryErr{error: err, temporary: true}
// NewTemporaryError creates a new error which satisfies the Temporary interface.
func NewTemporaryError(err error) *temporaryErr {
return &temporaryErr{error: err, temporary: true}
}

View file

@ -17,22 +17,24 @@ package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"strings"
"time"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
realis "github.com/paypal/gorealis"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/paypal/gorealis/response"
)
var cmd, executor, url, clustersConfig, clusterName, updateId, username, password, zkUrl, hostList, role string
var caCertsPath string
var clientKey, clientCert string
var ConnectionTimeout = 20 * time.Second
var ConnectionTimeout = 20000
func init() {
flag.StringVar(&cmd, "cmd", "", "Aurora Job request type to send to Aurora Scheduler")
flag.StringVar(&cmd, "cmd", "", "Job request type to send to Aurora Scheduler")
flag.StringVar(&executor, "executor", "thermos", "Executor to use")
flag.StringVar(&url, "url", "", "URL at which the Aurora Scheduler exists as [url]:[port]")
flag.StringVar(&clustersConfig, "clusters", "", "Location of the clusters.json file used by aurora.")
@ -72,14 +74,15 @@ func init() {
func main() {
var job *realis.AuroraJob
var job realis.Job
var err error
var r *realis.Client
var monitor *realis.Monitor
var r realis.Realis
clientOptions := []realis.ClientOption{
realis.BasicAuth(username, password),
realis.ThriftJSON(),
realis.Timeout(ConnectionTimeout),
realis.TimeoutMS(ConnectionTimeout),
realis.BackOff(realis.Backoff{
Steps: 2,
Duration: 10 * time.Second,
@ -97,39 +100,39 @@ func main() {
}
if caCertsPath != "" {
clientOptions = append(clientOptions, realis.CertsPath(caCertsPath))
clientOptions = append(clientOptions, realis.Certspath(caCertsPath))
}
if clientKey != "" && clientCert != "" {
clientOptions = append(clientOptions, realis.ClientCerts(clientKey, clientCert))
}
r, err = realis.NewClient(clientOptions...)
r, err = realis.NewRealisClient(clientOptions...)
if err != nil {
log.Fatalln(err)
}
monitor = &realis.Monitor{r}
defer r.Close()
switch executor {
case "thermos":
thermosExec := realis.ThermosExecutor{}
thermosExec.AddProcess(realis.NewThermosProcess("boostrap", "echo bootsrapping")).
AddProcess(realis.NewThermosProcess("hello_gorealis", "while true; do echo hello world from gorealis; sleep 10; done"))
payload, err := ioutil.ReadFile("examples/thermos_payload.json")
if err != nil {
log.Fatalln("Error reading json config file: ", err)
}
job = realis.NewJob().
Environment("prod").
Role("vagrant").
Name("hello_world_from_gorealis").
ExecutorName(aurora.AURORA_EXECUTOR_NAME).
ExecutorData(string(payload)).
CPU(1).
RAM(64).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1).
ThermosExecutor(thermosExec)
AddPorts(1)
case "compose":
job = realis.NewJob().
Environment("prod").
@ -141,9 +144,6 @@ func main() {
RAM(512).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(4).
AddLabel("fileName", "sample-app/docker-compose.yml").
@ -157,9 +157,6 @@ func main() {
RAM(64).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1)
default:
@ -169,13 +166,14 @@ func main() {
switch cmd {
case "create":
fmt.Println("Creating job")
err := r.CreateJob(job)
resp, err := r.CreateJob(job)
if err != nil {
log.Fatalln(err)
}
fmt.Println(resp.String())
if ok, mErr := r.MonitorInstances(job.JobKey(), job.GetInstanceCount(), 5*time.Second, 50*time.Second); !ok || mErr != nil {
err := r.KillJob(job.JobKey())
if ok, mErr := monitor.Instances(job.JobKey(), job.GetInstanceCount(), 5, 50); !ok || mErr != nil {
_, err := r.KillJob(job.JobKey())
if err != nil {
log.Fatalln(err)
}
@ -185,17 +183,18 @@ func main() {
case "createService":
// Create a service with three instances using the update API instead of the createJob API
fmt.Println("Creating service")
settings := realis.JobUpdateFromAuroraTask(job.AuroraTask()).InstanceCount(3)
result, err := r.CreateService(settings)
settings := realis.NewUpdateSettings()
job.InstanceCount(3)
resp, result, err := r.CreateService(job, settings)
if err != nil {
log.Fatal("error: ", err)
log.Println("error: ", err)
log.Fatal("response: ", resp.String())
}
fmt.Println(result.String())
if ok, mErr := r.MonitorJobUpdate(*result.GetKey(), 5*time.Second, 180*time.Second); !ok || mErr != nil {
err := r.AbortJobUpdate(*result.GetKey(), "Monitor timed out")
err = r.KillJob(job.JobKey())
if ok, mErr := monitor.JobUpdate(*result.GetKey(), 5, 180); !ok || mErr != nil {
_, err := r.AbortJobUpdate(*result.GetKey(), "Monitor timed out")
_, err = r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
@ -206,13 +205,14 @@ func main() {
fmt.Println("Creating a docker based job")
container := realis.NewDockerContainer().Image("python:2.7").AddParameter("network", "host")
job.Container(container)
err := r.CreateJob(job)
resp, err := r.CreateJob(job)
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
if ok, err := r.MonitorInstances(job.JobKey(), job.GetInstanceCount(), 10*time.Second, 300*time.Second); !ok || err != nil {
err := r.KillJob(job.JobKey())
if ok, err := monitor.Instances(job.JobKey(), job.GetInstanceCount(), 10, 300); !ok || err != nil {
_, err := r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
@ -222,13 +222,14 @@ func main() {
fmt.Println("Creating a docker based job")
container := realis.NewMesosContainer().DockerImage("python", "2.7")
job.Container(container)
err := r.CreateJob(job)
resp, err := r.CreateJob(job)
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
if ok, err := r.MonitorInstances(job.JobKey(), job.GetInstanceCount(), 10*time.Second, 300*time.Second); !ok || err != nil {
err := r.KillJob(job.JobKey())
if ok, err := monitor.Instances(job.JobKey(), job.GetInstanceCount(), 10, 300); !ok || err != nil {
_, err := r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
@ -239,44 +240,50 @@ func main() {
// Cron config
job.CronSchedule("* * * * *")
job.IsService(false)
err := r.ScheduleCronJob(job)
resp, err := r.ScheduleCronJob(job)
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "startCron":
fmt.Println("Starting a Cron job")
err := r.StartCronJob(job.JobKey())
resp, err := r.StartCronJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "descheduleCron":
fmt.Println("Descheduling a Cron job")
err := r.DescheduleCronJob(job.JobKey())
resp, err := r.DescheduleCronJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "kill":
fmt.Println("Killing job")
err := r.KillJob(job.JobKey())
resp, err := r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
if ok, err := r.MonitorInstances(job.JobKey(), 0, 5*time.Second, 50*time.Second); !ok || err != nil {
if ok, err := monitor.Instances(job.JobKey(), 0, 5, 50); !ok || err != nil {
log.Fatal("Unable to kill all instances of job")
}
fmt.Println(resp.String())
case "restart":
fmt.Println("Restarting job")
err := r.RestartJob(job.JobKey())
resp, err := r.RestartJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "liveCount":
fmt.Println("Getting instance count")
@ -295,110 +302,106 @@ func main() {
log.Fatal(err)
}
fmt.Println("Active instances: ", live)
fmt.Println("Number of live instances: ", len(live))
case "flexUp":
fmt.Println("Flexing up job")
numOfInstances := 4
numOfInstances := int32(4)
live, err := r.GetInstanceIds(job.JobKey(), aurora.ACTIVE_STATES)
if err != nil {
log.Fatal(err)
}
currInstances := len(live)
currInstances := int32(len(live))
fmt.Println("Current num of instances: ", currInstances)
key := job.JobKey()
err = r.AddInstances(aurora.InstanceKey{
JobKey: &key,
resp, err := r.AddInstances(aurora.InstanceKey{
JobKey: job.JobKey(),
InstanceId: live[0],
},
int32(numOfInstances))
numOfInstances)
if err != nil {
log.Fatal(err)
}
if ok, err := r.MonitorInstances(job.JobKey(), int32(currInstances+numOfInstances), 5*time.Second, 50*time.Second); !ok || err != nil {
if ok, err := monitor.Instances(job.JobKey(), currInstances+numOfInstances, 5, 50); !ok || err != nil {
fmt.Println("Flexing up failed")
}
fmt.Println(resp.String())
case "flexDown":
fmt.Println("Flexing down job")
numOfInstances := 2
numOfInstances := int32(2)
live, err := r.GetInstanceIds(job.JobKey(), aurora.ACTIVE_STATES)
if err != nil {
log.Fatal(err)
}
currInstances := len(live)
currInstances := int32(len(live))
fmt.Println("Current num of instances: ", currInstances)
err = r.RemoveInstances(job.JobKey(), numOfInstances)
resp, err := r.RemoveInstances(job.JobKey(), numOfInstances)
if err != nil {
log.Fatal(err)
}
if ok, err := r.MonitorInstances(job.JobKey(), int32(currInstances-numOfInstances), 5*time.Second, 100*time.Second); !ok || err != nil {
if ok, err := monitor.Instances(job.JobKey(), currInstances-numOfInstances, 5, 100); !ok || err != nil {
fmt.Println("flexDown failed")
}
fmt.Println(resp.String())
case "update":
fmt.Println("Updating a job with with more RAM and to 5 instances")
live, err := r.GetInstanceIds(job.JobKey(), aurora.ACTIVE_STATES)
if err != nil {
log.Fatal(err)
}
key := job.JobKey()
taskConfig, err := r.FetchTaskConfig(aurora.InstanceKey{
JobKey: &key,
JobKey: job.JobKey(),
InstanceId: live[0],
})
if err != nil {
log.Fatal(err)
}
updateJob := realis.JobUpdateFromConfig(taskConfig).InstanceCount(5).RAM(128)
updateJob := realis.NewDefaultUpdateJob(taskConfig)
updateJob.InstanceCount(5).RAM(128)
result, err := r.StartJobUpdate(updateJob, "")
resp, err := r.StartJobUpdate(updateJob, "")
if err != nil {
log.Fatal(err)
}
jobUpdateKey := result.GetKey()
_, err = r.MonitorJobUpdate(*jobUpdateKey, 5*time.Second, 6*time.Minute)
if err != nil {
log.Fatal(err)
}
jobUpdateKey := response.JobUpdateKey(resp)
monitor.JobUpdate(*jobUpdateKey, 5, 500)
case "pauseJobUpdate":
key := job.JobKey()
err := r.PauseJobUpdate(&aurora.JobUpdateKey{
Job: &key,
resp, err := r.PauseJobUpdate(&aurora.JobUpdateKey{
Job: job.JobKey(),
ID: updateId,
}, "")
if err != nil {
log.Fatal(err)
}
fmt.Println("PauseJobUpdate response: ", resp.String())
case "resumeJobUpdate":
key := job.JobKey()
err := r.ResumeJobUpdate(aurora.JobUpdateKey{
Job: &key,
resp, err := r.ResumeJobUpdate(&aurora.JobUpdateKey{
Job: job.JobKey(),
ID: updateId,
}, "")
if err != nil {
log.Fatal(err)
}
fmt.Println("ResumeJobUpdate response: ", resp.String())
case "pulseJobUpdate":
key := job.JobKey()
resp, err := r.PulseJobUpdate(aurora.JobUpdateKey{
Job: &key,
resp, err := r.PulseJobUpdate(&aurora.JobUpdateKey{
Job: job.JobKey(),
ID: updateId,
})
if err != nil {
@ -408,10 +411,9 @@ func main() {
fmt.Println("PulseJobUpdate response: ", resp.String())
case "updateDetails":
key := job.JobKey()
result, err := r.JobUpdateDetails(aurora.JobUpdateQuery{
resp, err := r.JobUpdateDetails(aurora.JobUpdateQuery{
Key: &aurora.JobUpdateKey{
Job: &key,
Job: job.JobKey(),
ID: updateId,
},
Limit: 1,
@ -421,13 +423,12 @@ func main() {
log.Fatal(err)
}
fmt.Println(result)
fmt.Println(response.JobUpdateDetails(resp))
case "abortUpdate":
fmt.Println("Abort update")
key := job.JobKey()
err := r.AbortJobUpdate(aurora.JobUpdateKey{
Job: &key,
resp, err := r.AbortJobUpdate(aurora.JobUpdateKey{
Job: job.JobKey(),
ID: updateId,
},
"")
@ -435,12 +436,12 @@ func main() {
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "rollbackUpdate":
fmt.Println("Abort update")
key := job.JobKey()
err := r.RollbackJobUpdate(aurora.JobUpdateKey{
Job: &key,
resp, err := r.RollbackJobUpdate(aurora.JobUpdateKey{
Job: job.JobKey(),
ID: updateId,
},
"")
@ -448,6 +449,14 @@ func main() {
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "variableBatchStep":
step, err := r.VariableBatchStep(aurora.JobUpdateKey{Job: job.JobKey(), ID: updateId})
if err != nil {
log.Fatal(err)
}
fmt.Println(step)
case "taskConfig":
fmt.Println("Getting job info")
@ -456,9 +465,8 @@ func main() {
log.Fatal(err)
}
key := job.JobKey()
config, err := r.FetchTaskConfig(aurora.InstanceKey{
JobKey: &key,
JobKey: job.JobKey(),
InstanceId: live[0],
})
@ -470,10 +478,9 @@ func main() {
case "updatesummary":
fmt.Println("Getting job update summary")
key := job.JobKey()
jobquery := &aurora.JobUpdateQuery{
Role: &key.Role,
JobKey: &key,
Role: &job.JobKey().Role,
JobKey: job.JobKey(),
}
updatesummary, err := r.GetJobUpdateSummaries(jobquery)
if err != nil {
@ -484,11 +491,10 @@ func main() {
case "taskStatus":
fmt.Println("Getting task status")
key := job.JobKey()
taskQ := &aurora.TaskQuery{
Role: &key.Role,
Environment: &key.Environment,
JobName: &key.Name,
Role: &job.JobKey().Role,
Environment: &job.JobKey().Environment,
JobName: &job.JobKey().Name,
}
tasks, err := r.GetTaskStatus(taskQ)
if err != nil {
@ -500,11 +506,10 @@ func main() {
case "tasksWithoutConfig":
fmt.Println("Getting task status")
key := job.JobKey()
taskQ := &aurora.TaskQuery{
Role: &key.Role,
Environment: &key.Environment,
JobName: &key.Name,
Role: &job.JobKey().Role,
Environment: &job.JobKey().Environment,
JobName: &job.JobKey().Name,
}
tasks, err := r.GetTasksWithoutConfigs(taskQ)
if err != nil {
@ -520,17 +525,17 @@ func main() {
log.Fatal("No hosts specified to drain")
}
hosts := strings.Split(hostList, ",")
_, err := r.DrainHosts(hosts...)
_, result, err := r.DrainHosts(hosts...)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}
// Monitor change to DRAINING and DRAINED mode
hostResult, err := r.MonitorHostMaintenance(
hostResult, err := monitor.HostMaintenance(
hosts,
[]aurora.MaintenanceMode{aurora.MaintenanceMode_DRAINED, aurora.MaintenanceMode_DRAINING},
5*time.Second,
10*time.Second)
5,
10)
if err != nil {
for host, ok := range hostResult {
if !ok {
@ -540,6 +545,8 @@ func main() {
log.Fatalf("error: %+v\n", err.Error())
}
fmt.Print(result.String())
case "SLADrainHosts":
fmt.Println("Setting hosts to DRAINING using SLA aware draining")
if hostList == "" {
@ -549,17 +556,17 @@ func main() {
policy := aurora.SlaPolicy{PercentageSlaPolicy: &aurora.PercentageSlaPolicy{Percentage: 50.0}}
_, err := r.SLADrainHosts(&policy, 30, hosts...)
result, err := r.SLADrainHosts(&policy, 30, hosts...)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}
// Monitor change to DRAINING and DRAINED mode
hostResult, err := r.MonitorHostMaintenance(
hostResult, err := monitor.HostMaintenance(
hosts,
[]aurora.MaintenanceMode{aurora.MaintenanceMode_DRAINED, aurora.MaintenanceMode_DRAINING},
5*time.Second,
10*time.Second)
5,
10)
if err != nil {
for host, ok := range hostResult {
if !ok {
@ -569,23 +576,25 @@ func main() {
log.Fatalf("error: %+v\n", err.Error())
}
fmt.Print(result.String())
case "endMaintenance":
fmt.Println("Setting hosts to ACTIVE")
if hostList == "" {
log.Fatal("No hosts specified to drain")
}
hosts := strings.Split(hostList, ",")
_, err := r.EndMaintenance(hosts...)
_, result, err := r.EndMaintenance(hosts...)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}
// Monitor change to DRAINING and DRAINED mode
hostResult, err := r.MonitorHostMaintenance(
hostResult, err := monitor.HostMaintenance(
hosts,
[]aurora.MaintenanceMode{aurora.MaintenanceMode_NONE},
5*time.Second,
10*time.Second)
5,
10)
if err != nil {
for host, ok := range hostResult {
if !ok {
@ -595,13 +604,14 @@ func main() {
log.Fatalf("error: %+v\n", err.Error())
}
fmt.Print(result.String())
case "getPendingReasons":
fmt.Println("Getting pending reasons")
key := job.JobKey()
taskQ := &aurora.TaskQuery{
Role: &key.Role,
Environment: &key.Environment,
JobName: &key.Name,
Role: &job.JobKey().Role,
Environment: &job.JobKey().Environment,
JobName: &job.JobKey().Name,
}
reasons, err := r.GetPendingReason(taskQ)
if err != nil {
@ -613,7 +623,7 @@ func main() {
case "getJobs":
fmt.Println("GetJobs...role: ", role)
result, err := r.GetJobs(role)
_, result, err := r.GetJobs(role)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}

View file

@ -2,7 +2,6 @@
"name": "devcluster",
"zk": "192.168.33.7",
"scheduler_zk_path": "/aurora/scheduler",
"mesos_zk_path": "/mesos",
"auth_mechanism": "UNAUTHENTICATED",
"slave_run_directory": "latest",
"slave_root": "/var/lib/mesos"

View file

@ -23,8 +23,8 @@ import (
"os"
"time"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
realis "github.com/paypal/gorealis"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/pkg/errors"
)
@ -125,7 +125,7 @@ func init() {
}
}
func CreateRealisClient(config *Config) (*realis.Client, error) {
func CreateRealisClient(config *Config) (realis.Realis, error) {
var transportOption realis.ClientOption
// Configuring transport protocol. If not transport is provided, then using JSON as the
// default transport protocol.
@ -157,7 +157,7 @@ func CreateRealisClient(config *Config) (*realis.Client, error) {
clientOptions = append(clientOptions, realis.Debug())
}
return realis.NewClient(clientOptions...)
return realis.NewRealisClient(clientOptions...)
}
func main() {
@ -165,6 +165,7 @@ func main() {
fmt.Println(clientCreationErr)
os.Exit(1)
} else {
monitor := &realis.Monitor{Client: r}
defer r.Close()
uris := job.URIs
labels := job.Labels
@ -177,8 +178,6 @@ func main() {
RAM(job.RAM).
Disk(job.Disk).
IsService(job.Service).
Tier("preemptible").
Priority(0).
InstanceCount(job.Instances).
AddPorts(job.Ports)
@ -206,18 +205,20 @@ func main() {
}
fmt.Println("Creating Job...")
if jobCreationErr := r.CreateJob(auroraJob); jobCreationErr != nil {
if resp, jobCreationErr := r.CreateJob(auroraJob); jobCreationErr != nil {
fmt.Println("Error creating Aurora job: ", jobCreationErr)
os.Exit(1)
} else {
if ok, monitorErr := r.MonitorInstances(auroraJob.JobKey(), auroraJob.GetInstanceCount(), 5, 50); !ok || monitorErr != nil {
if jobErr := r.KillJob(auroraJob.JobKey()); jobErr !=
nil {
fmt.Println(jobErr)
os.Exit(1)
} else {
fmt.Println("ok: ", ok)
fmt.Println("jobErr: ", jobErr)
if resp.ResponseCode == aurora.ResponseCode_OK {
if ok, monitorErr := monitor.Instances(auroraJob.JobKey(), auroraJob.GetInstanceCount(), 5, 50); !ok || monitorErr != nil {
if _, jobErr := r.KillJob(auroraJob.JobKey()); jobErr !=
nil {
fmt.Println(jobErr)
os.Exit(1)
} else {
fmt.Println("ok: ", ok)
fmt.Println("jobErr: ", jobErr)
}
}
}
}

View file

@ -0,0 +1,62 @@
{
"environment": "prod",
"health_check_config": {
"initial_interval_secs": 15.0,
"health_checker": {
"http": {
"expected_response_code": 0,
"endpoint": "/health",
"expected_response": "ok"
}
},
"interval_secs": 10.0,
"timeout_secs": 1.0,
"max_consecutive_failures": 0
},
"name": "hello_world_from_gorealis",
"service": false,
"max_task_failures": 1,
"cron_collision_policy": "KILL_EXISTING",
"enable_hooks": false,
"cluster": "devcluster",
"task": {
"processes": [
{
"daemon": false,
"name": "hello",
"ephemeral": false,
"max_failures": 1,
"min_duration": 5,
"cmdline": "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
"final": false
}
],
"name": "hello",
"finalization_wait": 30,
"max_failures": 1,
"max_concurrency": 0,
"resources": {
"gpu": 0,
"disk": 134217728,
"ram": 134217728,
"cpu": 1.0
},
"constraints": [
{
"order": [
"hello"
]
}
]
},
"production": false,
"role": "vagrant",
"lifecycle": {
"http": {
"graceful_shutdown_endpoint": "/quitquitquit",
"port": "health",
"shutdown_endpoint": "/abortabortabort"
}
},
"priority": 0
}

View file

@ -1,28 +0,0 @@
{
"task": {
"processes": [
{
"daemon": false,
"name": "hello",
"ephemeral": false,
"max_failures": 1,
"min_duration": 5,
"cmdline": "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
"final": false
}
],
"resources": {
"gpu": 0,
"disk": 134217728,
"ram": 134217728,
"cpu": 1.1
},
"constraints": [
{
"order": [
"hello"
]
}
]
}
}

View file

@ -1,4 +1,5 @@
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package aurora

View file

@ -1,12 +1,13 @@
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package aurora
import(
import (
"bytes"
"context"
"reflect"
"fmt"
"time"
"github.com/apache/thrift/lib/go/thrift"
)
@ -14,7 +15,7 @@ import(
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = context.Background
var _ = time.Now
var _ = reflect.DeepEqual
var _ = bytes.Equal
const AURORA_EXECUTOR_NAME = "AuroraExecutor"

File diff suppressed because it is too large Load diff

View file

@ -1,22 +1,22 @@
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package main
import (
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
)
var _ = aurora.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
@ -175,19 +175,19 @@ func main() {
fmt.Fprintln(os.Stderr, "CreateJob requires 1 args")
flag.Usage()
}
arg213 := flag.Arg(1)
mbTrans214 := thrift.NewTMemoryBufferLen(len(arg213))
defer mbTrans214.Close()
_, err215 := mbTrans214.WriteString(arg213)
if err215 != nil {
arg163 := flag.Arg(1)
mbTrans164 := thrift.NewTMemoryBufferLen(len(arg163))
defer mbTrans164.Close()
_, err165 := mbTrans164.WriteString(arg163)
if err165 != nil {
Usage()
return
}
factory216 := thrift.NewTJSONProtocolFactory()
jsProt217 := factory216.GetProtocol(mbTrans214)
factory166 := thrift.NewTJSONProtocolFactory()
jsProt167 := factory166.GetProtocol(mbTrans164)
argvalue0 := aurora.NewJobConfiguration()
err218 := argvalue0.Read(context.Background(), jsProt217)
if err218 != nil {
err168 := argvalue0.Read(jsProt167)
if err168 != nil {
Usage()
return
}
@ -200,19 +200,19 @@ func main() {
fmt.Fprintln(os.Stderr, "ScheduleCronJob requires 1 args")
flag.Usage()
}
arg219 := flag.Arg(1)
mbTrans220 := thrift.NewTMemoryBufferLen(len(arg219))
defer mbTrans220.Close()
_, err221 := mbTrans220.WriteString(arg219)
if err221 != nil {
arg169 := flag.Arg(1)
mbTrans170 := thrift.NewTMemoryBufferLen(len(arg169))
defer mbTrans170.Close()
_, err171 := mbTrans170.WriteString(arg169)
if err171 != nil {
Usage()
return
}
factory222 := thrift.NewTJSONProtocolFactory()
jsProt223 := factory222.GetProtocol(mbTrans220)
factory172 := thrift.NewTJSONProtocolFactory()
jsProt173 := factory172.GetProtocol(mbTrans170)
argvalue0 := aurora.NewJobConfiguration()
err224 := argvalue0.Read(context.Background(), jsProt223)
if err224 != nil {
err174 := argvalue0.Read(jsProt173)
if err174 != nil {
Usage()
return
}
@ -225,19 +225,19 @@ func main() {
fmt.Fprintln(os.Stderr, "DescheduleCronJob requires 1 args")
flag.Usage()
}
arg225 := flag.Arg(1)
mbTrans226 := thrift.NewTMemoryBufferLen(len(arg225))
defer mbTrans226.Close()
_, err227 := mbTrans226.WriteString(arg225)
if err227 != nil {
arg175 := flag.Arg(1)
mbTrans176 := thrift.NewTMemoryBufferLen(len(arg175))
defer mbTrans176.Close()
_, err177 := mbTrans176.WriteString(arg175)
if err177 != nil {
Usage()
return
}
factory228 := thrift.NewTJSONProtocolFactory()
jsProt229 := factory228.GetProtocol(mbTrans226)
factory178 := thrift.NewTJSONProtocolFactory()
jsProt179 := factory178.GetProtocol(mbTrans176)
argvalue0 := aurora.NewJobKey()
err230 := argvalue0.Read(context.Background(), jsProt229)
if err230 != nil {
err180 := argvalue0.Read(jsProt179)
if err180 != nil {
Usage()
return
}
@ -250,19 +250,19 @@ func main() {
fmt.Fprintln(os.Stderr, "StartCronJob requires 1 args")
flag.Usage()
}
arg231 := flag.Arg(1)
mbTrans232 := thrift.NewTMemoryBufferLen(len(arg231))
defer mbTrans232.Close()
_, err233 := mbTrans232.WriteString(arg231)
if err233 != nil {
arg181 := flag.Arg(1)
mbTrans182 := thrift.NewTMemoryBufferLen(len(arg181))
defer mbTrans182.Close()
_, err183 := mbTrans182.WriteString(arg181)
if err183 != nil {
Usage()
return
}
factory234 := thrift.NewTJSONProtocolFactory()
jsProt235 := factory234.GetProtocol(mbTrans232)
factory184 := thrift.NewTJSONProtocolFactory()
jsProt185 := factory184.GetProtocol(mbTrans182)
argvalue0 := aurora.NewJobKey()
err236 := argvalue0.Read(context.Background(), jsProt235)
if err236 != nil {
err186 := argvalue0.Read(jsProt185)
if err186 != nil {
Usage()
return
}
@ -275,36 +275,36 @@ func main() {
fmt.Fprintln(os.Stderr, "RestartShards requires 2 args")
flag.Usage()
}
arg237 := flag.Arg(1)
mbTrans238 := thrift.NewTMemoryBufferLen(len(arg237))
defer mbTrans238.Close()
_, err239 := mbTrans238.WriteString(arg237)
if err239 != nil {
arg187 := flag.Arg(1)
mbTrans188 := thrift.NewTMemoryBufferLen(len(arg187))
defer mbTrans188.Close()
_, err189 := mbTrans188.WriteString(arg187)
if err189 != nil {
Usage()
return
}
factory240 := thrift.NewTJSONProtocolFactory()
jsProt241 := factory240.GetProtocol(mbTrans238)
factory190 := thrift.NewTJSONProtocolFactory()
jsProt191 := factory190.GetProtocol(mbTrans188)
argvalue0 := aurora.NewJobKey()
err242 := argvalue0.Read(context.Background(), jsProt241)
if err242 != nil {
err192 := argvalue0.Read(jsProt191)
if err192 != nil {
Usage()
return
}
value0 := argvalue0
arg243 := flag.Arg(2)
mbTrans244 := thrift.NewTMemoryBufferLen(len(arg243))
defer mbTrans244.Close()
_, err245 := mbTrans244.WriteString(arg243)
if err245 != nil {
arg193 := flag.Arg(2)
mbTrans194 := thrift.NewTMemoryBufferLen(len(arg193))
defer mbTrans194.Close()
_, err195 := mbTrans194.WriteString(arg193)
if err195 != nil {
Usage()
return
}
factory246 := thrift.NewTJSONProtocolFactory()
jsProt247 := factory246.GetProtocol(mbTrans244)
factory196 := thrift.NewTJSONProtocolFactory()
jsProt197 := factory196.GetProtocol(mbTrans194)
containerStruct1 := aurora.NewAuroraSchedulerManagerRestartShardsArgs()
err248 := containerStruct1.ReadField2(context.Background(), jsProt247)
if err248 != nil {
err198 := containerStruct1.ReadField2(jsProt197)
if err198 != nil {
Usage()
return
}
@ -318,36 +318,36 @@ func main() {
fmt.Fprintln(os.Stderr, "KillTasks requires 3 args")
flag.Usage()
}
arg249 := flag.Arg(1)
mbTrans250 := thrift.NewTMemoryBufferLen(len(arg249))
defer mbTrans250.Close()
_, err251 := mbTrans250.WriteString(arg249)
if err251 != nil {
arg199 := flag.Arg(1)
mbTrans200 := thrift.NewTMemoryBufferLen(len(arg199))
defer mbTrans200.Close()
_, err201 := mbTrans200.WriteString(arg199)
if err201 != nil {
Usage()
return
}
factory252 := thrift.NewTJSONProtocolFactory()
jsProt253 := factory252.GetProtocol(mbTrans250)
factory202 := thrift.NewTJSONProtocolFactory()
jsProt203 := factory202.GetProtocol(mbTrans200)
argvalue0 := aurora.NewJobKey()
err254 := argvalue0.Read(context.Background(), jsProt253)
if err254 != nil {
err204 := argvalue0.Read(jsProt203)
if err204 != nil {
Usage()
return
}
value0 := argvalue0
arg255 := flag.Arg(2)
mbTrans256 := thrift.NewTMemoryBufferLen(len(arg255))
defer mbTrans256.Close()
_, err257 := mbTrans256.WriteString(arg255)
if err257 != nil {
arg205 := flag.Arg(2)
mbTrans206 := thrift.NewTMemoryBufferLen(len(arg205))
defer mbTrans206.Close()
_, err207 := mbTrans206.WriteString(arg205)
if err207 != nil {
Usage()
return
}
factory258 := thrift.NewTJSONProtocolFactory()
jsProt259 := factory258.GetProtocol(mbTrans256)
factory208 := thrift.NewTJSONProtocolFactory()
jsProt209 := factory208.GetProtocol(mbTrans206)
containerStruct1 := aurora.NewAuroraSchedulerManagerKillTasksArgs()
err260 := containerStruct1.ReadField2(context.Background(), jsProt259)
if err260 != nil {
err210 := containerStruct1.ReadField2(jsProt209)
if err210 != nil {
Usage()
return
}
@ -363,25 +363,25 @@ func main() {
fmt.Fprintln(os.Stderr, "AddInstances requires 2 args")
flag.Usage()
}
arg262 := flag.Arg(1)
mbTrans263 := thrift.NewTMemoryBufferLen(len(arg262))
defer mbTrans263.Close()
_, err264 := mbTrans263.WriteString(arg262)
if err264 != nil {
arg212 := flag.Arg(1)
mbTrans213 := thrift.NewTMemoryBufferLen(len(arg212))
defer mbTrans213.Close()
_, err214 := mbTrans213.WriteString(arg212)
if err214 != nil {
Usage()
return
}
factory265 := thrift.NewTJSONProtocolFactory()
jsProt266 := factory265.GetProtocol(mbTrans263)
factory215 := thrift.NewTJSONProtocolFactory()
jsProt216 := factory215.GetProtocol(mbTrans213)
argvalue0 := aurora.NewInstanceKey()
err267 := argvalue0.Read(context.Background(), jsProt266)
if err267 != nil {
err217 := argvalue0.Read(jsProt216)
if err217 != nil {
Usage()
return
}
value0 := argvalue0
tmp1, err268 := (strconv.Atoi(flag.Arg(2)))
if err268 != nil {
tmp1, err218 := (strconv.Atoi(flag.Arg(2)))
if err218 != nil {
Usage()
return
}
@ -395,19 +395,19 @@ func main() {
fmt.Fprintln(os.Stderr, "ReplaceCronTemplate requires 1 args")
flag.Usage()
}
arg269 := flag.Arg(1)
mbTrans270 := thrift.NewTMemoryBufferLen(len(arg269))
defer mbTrans270.Close()
_, err271 := mbTrans270.WriteString(arg269)
if err271 != nil {
arg219 := flag.Arg(1)
mbTrans220 := thrift.NewTMemoryBufferLen(len(arg219))
defer mbTrans220.Close()
_, err221 := mbTrans220.WriteString(arg219)
if err221 != nil {
Usage()
return
}
factory272 := thrift.NewTJSONProtocolFactory()
jsProt273 := factory272.GetProtocol(mbTrans270)
factory222 := thrift.NewTJSONProtocolFactory()
jsProt223 := factory222.GetProtocol(mbTrans220)
argvalue0 := aurora.NewJobConfiguration()
err274 := argvalue0.Read(context.Background(), jsProt273)
if err274 != nil {
err224 := argvalue0.Read(jsProt223)
if err224 != nil {
Usage()
return
}
@ -420,19 +420,19 @@ func main() {
fmt.Fprintln(os.Stderr, "StartJobUpdate requires 2 args")
flag.Usage()
}
arg275 := flag.Arg(1)
mbTrans276 := thrift.NewTMemoryBufferLen(len(arg275))
defer mbTrans276.Close()
_, err277 := mbTrans276.WriteString(arg275)
if err277 != nil {
arg225 := flag.Arg(1)
mbTrans226 := thrift.NewTMemoryBufferLen(len(arg225))
defer mbTrans226.Close()
_, err227 := mbTrans226.WriteString(arg225)
if err227 != nil {
Usage()
return
}
factory278 := thrift.NewTJSONProtocolFactory()
jsProt279 := factory278.GetProtocol(mbTrans276)
factory228 := thrift.NewTJSONProtocolFactory()
jsProt229 := factory228.GetProtocol(mbTrans226)
argvalue0 := aurora.NewJobUpdateRequest()
err280 := argvalue0.Read(context.Background(), jsProt279)
if err280 != nil {
err230 := argvalue0.Read(jsProt229)
if err230 != nil {
Usage()
return
}
@ -447,19 +447,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PauseJobUpdate requires 2 args")
flag.Usage()
}
arg282 := flag.Arg(1)
mbTrans283 := thrift.NewTMemoryBufferLen(len(arg282))
defer mbTrans283.Close()
_, err284 := mbTrans283.WriteString(arg282)
if err284 != nil {
arg232 := flag.Arg(1)
mbTrans233 := thrift.NewTMemoryBufferLen(len(arg232))
defer mbTrans233.Close()
_, err234 := mbTrans233.WriteString(arg232)
if err234 != nil {
Usage()
return
}
factory285 := thrift.NewTJSONProtocolFactory()
jsProt286 := factory285.GetProtocol(mbTrans283)
factory235 := thrift.NewTJSONProtocolFactory()
jsProt236 := factory235.GetProtocol(mbTrans233)
argvalue0 := aurora.NewJobUpdateKey()
err287 := argvalue0.Read(context.Background(), jsProt286)
if err287 != nil {
err237 := argvalue0.Read(jsProt236)
if err237 != nil {
Usage()
return
}
@ -474,19 +474,19 @@ func main() {
fmt.Fprintln(os.Stderr, "ResumeJobUpdate requires 2 args")
flag.Usage()
}
arg289 := flag.Arg(1)
mbTrans290 := thrift.NewTMemoryBufferLen(len(arg289))
defer mbTrans290.Close()
_, err291 := mbTrans290.WriteString(arg289)
if err291 != nil {
arg239 := flag.Arg(1)
mbTrans240 := thrift.NewTMemoryBufferLen(len(arg239))
defer mbTrans240.Close()
_, err241 := mbTrans240.WriteString(arg239)
if err241 != nil {
Usage()
return
}
factory292 := thrift.NewTJSONProtocolFactory()
jsProt293 := factory292.GetProtocol(mbTrans290)
factory242 := thrift.NewTJSONProtocolFactory()
jsProt243 := factory242.GetProtocol(mbTrans240)
argvalue0 := aurora.NewJobUpdateKey()
err294 := argvalue0.Read(context.Background(), jsProt293)
if err294 != nil {
err244 := argvalue0.Read(jsProt243)
if err244 != nil {
Usage()
return
}
@ -501,19 +501,19 @@ func main() {
fmt.Fprintln(os.Stderr, "AbortJobUpdate requires 2 args")
flag.Usage()
}
arg296 := flag.Arg(1)
mbTrans297 := thrift.NewTMemoryBufferLen(len(arg296))
defer mbTrans297.Close()
_, err298 := mbTrans297.WriteString(arg296)
if err298 != nil {
arg246 := flag.Arg(1)
mbTrans247 := thrift.NewTMemoryBufferLen(len(arg246))
defer mbTrans247.Close()
_, err248 := mbTrans247.WriteString(arg246)
if err248 != nil {
Usage()
return
}
factory299 := thrift.NewTJSONProtocolFactory()
jsProt300 := factory299.GetProtocol(mbTrans297)
factory249 := thrift.NewTJSONProtocolFactory()
jsProt250 := factory249.GetProtocol(mbTrans247)
argvalue0 := aurora.NewJobUpdateKey()
err301 := argvalue0.Read(context.Background(), jsProt300)
if err301 != nil {
err251 := argvalue0.Read(jsProt250)
if err251 != nil {
Usage()
return
}
@ -528,19 +528,19 @@ func main() {
fmt.Fprintln(os.Stderr, "RollbackJobUpdate requires 2 args")
flag.Usage()
}
arg303 := flag.Arg(1)
mbTrans304 := thrift.NewTMemoryBufferLen(len(arg303))
defer mbTrans304.Close()
_, err305 := mbTrans304.WriteString(arg303)
if err305 != nil {
arg253 := flag.Arg(1)
mbTrans254 := thrift.NewTMemoryBufferLen(len(arg253))
defer mbTrans254.Close()
_, err255 := mbTrans254.WriteString(arg253)
if err255 != nil {
Usage()
return
}
factory306 := thrift.NewTJSONProtocolFactory()
jsProt307 := factory306.GetProtocol(mbTrans304)
factory256 := thrift.NewTJSONProtocolFactory()
jsProt257 := factory256.GetProtocol(mbTrans254)
argvalue0 := aurora.NewJobUpdateKey()
err308 := argvalue0.Read(context.Background(), jsProt307)
if err308 != nil {
err258 := argvalue0.Read(jsProt257)
if err258 != nil {
Usage()
return
}
@ -555,19 +555,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PulseJobUpdate requires 1 args")
flag.Usage()
}
arg310 := flag.Arg(1)
mbTrans311 := thrift.NewTMemoryBufferLen(len(arg310))
defer mbTrans311.Close()
_, err312 := mbTrans311.WriteString(arg310)
if err312 != nil {
arg260 := flag.Arg(1)
mbTrans261 := thrift.NewTMemoryBufferLen(len(arg260))
defer mbTrans261.Close()
_, err262 := mbTrans261.WriteString(arg260)
if err262 != nil {
Usage()
return
}
factory313 := thrift.NewTJSONProtocolFactory()
jsProt314 := factory313.GetProtocol(mbTrans311)
factory263 := thrift.NewTJSONProtocolFactory()
jsProt264 := factory263.GetProtocol(mbTrans261)
argvalue0 := aurora.NewJobUpdateKey()
err315 := argvalue0.Read(context.Background(), jsProt314)
if err315 != nil {
err265 := argvalue0.Read(jsProt264)
if err265 != nil {
Usage()
return
}
@ -598,19 +598,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksStatus requires 1 args")
flag.Usage()
}
arg317 := flag.Arg(1)
mbTrans318 := thrift.NewTMemoryBufferLen(len(arg317))
defer mbTrans318.Close()
_, err319 := mbTrans318.WriteString(arg317)
if err319 != nil {
arg267 := flag.Arg(1)
mbTrans268 := thrift.NewTMemoryBufferLen(len(arg267))
defer mbTrans268.Close()
_, err269 := mbTrans268.WriteString(arg267)
if err269 != nil {
Usage()
return
}
factory320 := thrift.NewTJSONProtocolFactory()
jsProt321 := factory320.GetProtocol(mbTrans318)
factory270 := thrift.NewTJSONProtocolFactory()
jsProt271 := factory270.GetProtocol(mbTrans268)
argvalue0 := aurora.NewTaskQuery()
err322 := argvalue0.Read(context.Background(), jsProt321)
if err322 != nil {
err272 := argvalue0.Read(jsProt271)
if err272 != nil {
Usage()
return
}
@ -623,19 +623,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksWithoutConfigs requires 1 args")
flag.Usage()
}
arg323 := flag.Arg(1)
mbTrans324 := thrift.NewTMemoryBufferLen(len(arg323))
defer mbTrans324.Close()
_, err325 := mbTrans324.WriteString(arg323)
if err325 != nil {
arg273 := flag.Arg(1)
mbTrans274 := thrift.NewTMemoryBufferLen(len(arg273))
defer mbTrans274.Close()
_, err275 := mbTrans274.WriteString(arg273)
if err275 != nil {
Usage()
return
}
factory326 := thrift.NewTJSONProtocolFactory()
jsProt327 := factory326.GetProtocol(mbTrans324)
factory276 := thrift.NewTJSONProtocolFactory()
jsProt277 := factory276.GetProtocol(mbTrans274)
argvalue0 := aurora.NewTaskQuery()
err328 := argvalue0.Read(context.Background(), jsProt327)
if err328 != nil {
err278 := argvalue0.Read(jsProt277)
if err278 != nil {
Usage()
return
}
@ -648,19 +648,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetPendingReason requires 1 args")
flag.Usage()
}
arg329 := flag.Arg(1)
mbTrans330 := thrift.NewTMemoryBufferLen(len(arg329))
defer mbTrans330.Close()
_, err331 := mbTrans330.WriteString(arg329)
if err331 != nil {
arg279 := flag.Arg(1)
mbTrans280 := thrift.NewTMemoryBufferLen(len(arg279))
defer mbTrans280.Close()
_, err281 := mbTrans280.WriteString(arg279)
if err281 != nil {
Usage()
return
}
factory332 := thrift.NewTJSONProtocolFactory()
jsProt333 := factory332.GetProtocol(mbTrans330)
factory282 := thrift.NewTJSONProtocolFactory()
jsProt283 := factory282.GetProtocol(mbTrans280)
argvalue0 := aurora.NewTaskQuery()
err334 := argvalue0.Read(context.Background(), jsProt333)
if err334 != nil {
err284 := argvalue0.Read(jsProt283)
if err284 != nil {
Usage()
return
}
@ -673,19 +673,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetConfigSummary requires 1 args")
flag.Usage()
}
arg335 := flag.Arg(1)
mbTrans336 := thrift.NewTMemoryBufferLen(len(arg335))
defer mbTrans336.Close()
_, err337 := mbTrans336.WriteString(arg335)
if err337 != nil {
arg285 := flag.Arg(1)
mbTrans286 := thrift.NewTMemoryBufferLen(len(arg285))
defer mbTrans286.Close()
_, err287 := mbTrans286.WriteString(arg285)
if err287 != nil {
Usage()
return
}
factory338 := thrift.NewTJSONProtocolFactory()
jsProt339 := factory338.GetProtocol(mbTrans336)
factory288 := thrift.NewTJSONProtocolFactory()
jsProt289 := factory288.GetProtocol(mbTrans286)
argvalue0 := aurora.NewJobKey()
err340 := argvalue0.Read(context.Background(), jsProt339)
if err340 != nil {
err290 := argvalue0.Read(jsProt289)
if err290 != nil {
Usage()
return
}
@ -718,19 +718,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PopulateJobConfig requires 1 args")
flag.Usage()
}
arg343 := flag.Arg(1)
mbTrans344 := thrift.NewTMemoryBufferLen(len(arg343))
defer mbTrans344.Close()
_, err345 := mbTrans344.WriteString(arg343)
if err345 != nil {
arg293 := flag.Arg(1)
mbTrans294 := thrift.NewTMemoryBufferLen(len(arg293))
defer mbTrans294.Close()
_, err295 := mbTrans294.WriteString(arg293)
if err295 != nil {
Usage()
return
}
factory346 := thrift.NewTJSONProtocolFactory()
jsProt347 := factory346.GetProtocol(mbTrans344)
factory296 := thrift.NewTJSONProtocolFactory()
jsProt297 := factory296.GetProtocol(mbTrans294)
argvalue0 := aurora.NewJobConfiguration()
err348 := argvalue0.Read(context.Background(), jsProt347)
if err348 != nil {
err298 := argvalue0.Read(jsProt297)
if err298 != nil {
Usage()
return
}
@ -743,19 +743,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateSummaries requires 1 args")
flag.Usage()
}
arg349 := flag.Arg(1)
mbTrans350 := thrift.NewTMemoryBufferLen(len(arg349))
defer mbTrans350.Close()
_, err351 := mbTrans350.WriteString(arg349)
if err351 != nil {
arg299 := flag.Arg(1)
mbTrans300 := thrift.NewTMemoryBufferLen(len(arg299))
defer mbTrans300.Close()
_, err301 := mbTrans300.WriteString(arg299)
if err301 != nil {
Usage()
return
}
factory352 := thrift.NewTJSONProtocolFactory()
jsProt353 := factory352.GetProtocol(mbTrans350)
factory302 := thrift.NewTJSONProtocolFactory()
jsProt303 := factory302.GetProtocol(mbTrans300)
argvalue0 := aurora.NewJobUpdateQuery()
err354 := argvalue0.Read(context.Background(), jsProt353)
if err354 != nil {
err304 := argvalue0.Read(jsProt303)
if err304 != nil {
Usage()
return
}
@ -768,19 +768,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDetails requires 1 args")
flag.Usage()
}
arg355 := flag.Arg(1)
mbTrans356 := thrift.NewTMemoryBufferLen(len(arg355))
defer mbTrans356.Close()
_, err357 := mbTrans356.WriteString(arg355)
if err357 != nil {
arg305 := flag.Arg(1)
mbTrans306 := thrift.NewTMemoryBufferLen(len(arg305))
defer mbTrans306.Close()
_, err307 := mbTrans306.WriteString(arg305)
if err307 != nil {
Usage()
return
}
factory358 := thrift.NewTJSONProtocolFactory()
jsProt359 := factory358.GetProtocol(mbTrans356)
factory308 := thrift.NewTJSONProtocolFactory()
jsProt309 := factory308.GetProtocol(mbTrans306)
argvalue0 := aurora.NewJobUpdateQuery()
err360 := argvalue0.Read(context.Background(), jsProt359)
if err360 != nil {
err310 := argvalue0.Read(jsProt309)
if err310 != nil {
Usage()
return
}
@ -793,19 +793,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDiff requires 1 args")
flag.Usage()
}
arg361 := flag.Arg(1)
mbTrans362 := thrift.NewTMemoryBufferLen(len(arg361))
defer mbTrans362.Close()
_, err363 := mbTrans362.WriteString(arg361)
if err363 != nil {
arg311 := flag.Arg(1)
mbTrans312 := thrift.NewTMemoryBufferLen(len(arg311))
defer mbTrans312.Close()
_, err313 := mbTrans312.WriteString(arg311)
if err313 != nil {
Usage()
return
}
factory364 := thrift.NewTJSONProtocolFactory()
jsProt365 := factory364.GetProtocol(mbTrans362)
factory314 := thrift.NewTJSONProtocolFactory()
jsProt315 := factory314.GetProtocol(mbTrans312)
argvalue0 := aurora.NewJobUpdateRequest()
err366 := argvalue0.Read(context.Background(), jsProt365)
if err366 != nil {
err316 := argvalue0.Read(jsProt315)
if err316 != nil {
Usage()
return
}

View file

@ -1,22 +1,22 @@
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
package main
import (
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
)
var _ = aurora.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
@ -179,19 +179,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksStatus requires 1 args")
flag.Usage()
}
arg132 := flag.Arg(1)
mbTrans133 := thrift.NewTMemoryBufferLen(len(arg132))
defer mbTrans133.Close()
_, err134 := mbTrans133.WriteString(arg132)
if err134 != nil {
arg82 := flag.Arg(1)
mbTrans83 := thrift.NewTMemoryBufferLen(len(arg82))
defer mbTrans83.Close()
_, err84 := mbTrans83.WriteString(arg82)
if err84 != nil {
Usage()
return
}
factory135 := thrift.NewTJSONProtocolFactory()
jsProt136 := factory135.GetProtocol(mbTrans133)
factory85 := thrift.NewTJSONProtocolFactory()
jsProt86 := factory85.GetProtocol(mbTrans83)
argvalue0 := aurora.NewTaskQuery()
err137 := argvalue0.Read(context.Background(), jsProt136)
if err137 != nil {
err87 := argvalue0.Read(jsProt86)
if err87 != nil {
Usage()
return
}
@ -204,19 +204,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksWithoutConfigs requires 1 args")
flag.Usage()
}
arg138 := flag.Arg(1)
mbTrans139 := thrift.NewTMemoryBufferLen(len(arg138))
defer mbTrans139.Close()
_, err140 := mbTrans139.WriteString(arg138)
if err140 != nil {
arg88 := flag.Arg(1)
mbTrans89 := thrift.NewTMemoryBufferLen(len(arg88))
defer mbTrans89.Close()
_, err90 := mbTrans89.WriteString(arg88)
if err90 != nil {
Usage()
return
}
factory141 := thrift.NewTJSONProtocolFactory()
jsProt142 := factory141.GetProtocol(mbTrans139)
factory91 := thrift.NewTJSONProtocolFactory()
jsProt92 := factory91.GetProtocol(mbTrans89)
argvalue0 := aurora.NewTaskQuery()
err143 := argvalue0.Read(context.Background(), jsProt142)
if err143 != nil {
err93 := argvalue0.Read(jsProt92)
if err93 != nil {
Usage()
return
}
@ -229,19 +229,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetPendingReason requires 1 args")
flag.Usage()
}
arg144 := flag.Arg(1)
mbTrans145 := thrift.NewTMemoryBufferLen(len(arg144))
defer mbTrans145.Close()
_, err146 := mbTrans145.WriteString(arg144)
if err146 != nil {
arg94 := flag.Arg(1)
mbTrans95 := thrift.NewTMemoryBufferLen(len(arg94))
defer mbTrans95.Close()
_, err96 := mbTrans95.WriteString(arg94)
if err96 != nil {
Usage()
return
}
factory147 := thrift.NewTJSONProtocolFactory()
jsProt148 := factory147.GetProtocol(mbTrans145)
factory97 := thrift.NewTJSONProtocolFactory()
jsProt98 := factory97.GetProtocol(mbTrans95)
argvalue0 := aurora.NewTaskQuery()
err149 := argvalue0.Read(context.Background(), jsProt148)
if err149 != nil {
err99 := argvalue0.Read(jsProt98)
if err99 != nil {
Usage()
return
}
@ -254,19 +254,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetConfigSummary requires 1 args")
flag.Usage()
}
arg150 := flag.Arg(1)
mbTrans151 := thrift.NewTMemoryBufferLen(len(arg150))
defer mbTrans151.Close()
_, err152 := mbTrans151.WriteString(arg150)
if err152 != nil {
arg100 := flag.Arg(1)
mbTrans101 := thrift.NewTMemoryBufferLen(len(arg100))
defer mbTrans101.Close()
_, err102 := mbTrans101.WriteString(arg100)
if err102 != nil {
Usage()
return
}
factory153 := thrift.NewTJSONProtocolFactory()
jsProt154 := factory153.GetProtocol(mbTrans151)
factory103 := thrift.NewTJSONProtocolFactory()
jsProt104 := factory103.GetProtocol(mbTrans101)
argvalue0 := aurora.NewJobKey()
err155 := argvalue0.Read(context.Background(), jsProt154)
if err155 != nil {
err105 := argvalue0.Read(jsProt104)
if err105 != nil {
Usage()
return
}
@ -299,19 +299,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PopulateJobConfig requires 1 args")
flag.Usage()
}
arg158 := flag.Arg(1)
mbTrans159 := thrift.NewTMemoryBufferLen(len(arg158))
defer mbTrans159.Close()
_, err160 := mbTrans159.WriteString(arg158)
if err160 != nil {
arg108 := flag.Arg(1)
mbTrans109 := thrift.NewTMemoryBufferLen(len(arg108))
defer mbTrans109.Close()
_, err110 := mbTrans109.WriteString(arg108)
if err110 != nil {
Usage()
return
}
factory161 := thrift.NewTJSONProtocolFactory()
jsProt162 := factory161.GetProtocol(mbTrans159)
factory111 := thrift.NewTJSONProtocolFactory()
jsProt112 := factory111.GetProtocol(mbTrans109)
argvalue0 := aurora.NewJobConfiguration()
err163 := argvalue0.Read(context.Background(), jsProt162)
if err163 != nil {
err113 := argvalue0.Read(jsProt112)
if err113 != nil {
Usage()
return
}
@ -324,19 +324,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateSummaries requires 1 args")
flag.Usage()
}
arg164 := flag.Arg(1)
mbTrans165 := thrift.NewTMemoryBufferLen(len(arg164))
defer mbTrans165.Close()
_, err166 := mbTrans165.WriteString(arg164)
if err166 != nil {
arg114 := flag.Arg(1)
mbTrans115 := thrift.NewTMemoryBufferLen(len(arg114))
defer mbTrans115.Close()
_, err116 := mbTrans115.WriteString(arg114)
if err116 != nil {
Usage()
return
}
factory167 := thrift.NewTJSONProtocolFactory()
jsProt168 := factory167.GetProtocol(mbTrans165)
factory117 := thrift.NewTJSONProtocolFactory()
jsProt118 := factory117.GetProtocol(mbTrans115)
argvalue0 := aurora.NewJobUpdateQuery()
err169 := argvalue0.Read(context.Background(), jsProt168)
if err169 != nil {
err119 := argvalue0.Read(jsProt118)
if err119 != nil {
Usage()
return
}
@ -349,19 +349,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDetails requires 1 args")
flag.Usage()
}
arg170 := flag.Arg(1)
mbTrans171 := thrift.NewTMemoryBufferLen(len(arg170))
defer mbTrans171.Close()
_, err172 := mbTrans171.WriteString(arg170)
if err172 != nil {
arg120 := flag.Arg(1)
mbTrans121 := thrift.NewTMemoryBufferLen(len(arg120))
defer mbTrans121.Close()
_, err122 := mbTrans121.WriteString(arg120)
if err122 != nil {
Usage()
return
}
factory173 := thrift.NewTJSONProtocolFactory()
jsProt174 := factory173.GetProtocol(mbTrans171)
factory123 := thrift.NewTJSONProtocolFactory()
jsProt124 := factory123.GetProtocol(mbTrans121)
argvalue0 := aurora.NewJobUpdateQuery()
err175 := argvalue0.Read(context.Background(), jsProt174)
if err175 != nil {
err125 := argvalue0.Read(jsProt124)
if err125 != nil {
Usage()
return
}
@ -374,19 +374,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDiff requires 1 args")
flag.Usage()
}
arg176 := flag.Arg(1)
mbTrans177 := thrift.NewTMemoryBufferLen(len(arg176))
defer mbTrans177.Close()
_, err178 := mbTrans177.WriteString(arg176)
if err178 != nil {
arg126 := flag.Arg(1)
mbTrans127 := thrift.NewTMemoryBufferLen(len(arg126))
defer mbTrans127.Close()
_, err128 := mbTrans127.WriteString(arg126)
if err128 != nil {
Usage()
return
}
factory179 := thrift.NewTJSONProtocolFactory()
jsProt180 := factory179.GetProtocol(mbTrans177)
factory129 := thrift.NewTJSONProtocolFactory()
jsProt130 := factory129.GetProtocol(mbTrans127)
argvalue0 := aurora.NewJobUpdateRequest()
err181 := argvalue0.Read(context.Background(), jsProt180)
if err181 != nil {
err131 := argvalue0.Read(jsProt130)
if err131 != nil {
Usage()
return
}

View file

@ -1,6 +1,6 @@
#! /bin/bash
THRIFT_VER=0.14.0
THRIFT_VER=0.12.0
if [[ $(thrift -version | grep -e $THRIFT_VER -c) -ne 1 ]]; then
echo "Warning: This wrapper has only been tested with version" $THRIFT_VER;

14
go.mod
View file

@ -1,10 +1,12 @@
module github.com/aurora-scheduler/gorealis/v2
module github.com/paypal/gorealis
go 1.12
require (
github.com/apache/thrift v0.14.0
github.com/pkg/errors v0.9.1
github.com/apache/thrift v0.12.0
github.com/davecgh/go-spew v1.1.0
github.com/pkg/errors v0.0.0-20171216070316-e881fd58d78e
github.com/pmezard/go-difflib v1.0.0
github.com/samuel/go-zookeeper v0.0.0-20171117190445-471cd4e61d7a
github.com/stretchr/testify v1.7.0
github.com/stretchr/testify v1.2.0
)
go 1.16

23
go.sum
View file

@ -1,22 +1,9 @@
github.com/apache/thrift v0.14.0 h1:vqZ2DP42i8th2OsgCcYZkirtbzvpZEFx53LiWDJXIAs=
github.com/apache/thrift v0.14.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/apache/thrift v0.12.0 h1:pODnxUFNcjP9UTLZGTdeh+j16A8lJbRvD3rOtrk/7bs=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pkg/errors v0.0.0-20171216070316-e881fd58d78e h1:+RHxT/gm0O3UF7nLJbdNzAmULvCFt4XfXHWzh3XI/zs=
github.com/pkg/errors v0.0.0-20171216070316-e881fd58d78e/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/samuel/go-zookeeper v0.0.0-20171117190445-471cd4e61d7a h1:EYL2xz/Zdo0hyqdZMXR4lmT2O11jDLTPCEqIe/FR6W4=
github.com/samuel/go-zookeeper v0.0.0-20171117190445-471cd4e61d7a/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.0 h1:DMOzIV76tmoDNE9pX6RSN0aDtCYeCg5VueieJaAo1uw=
github.com/stretchr/testify v1.5.0/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=

View file

@ -1,23 +0,0 @@
package realis
import (
"context"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
func (r *Client) JobExists(key aurora.JobKey) (bool, error) {
resp, err := r.client.GetConfigSummary(context.TODO(), &key)
if err != nil {
return false, err
}
return resp != nil &&
resp.GetResult_() != nil &&
resp.GetResult_().GetConfigSummaryResult_() != nil &&
resp.GetResult_().GetConfigSummaryResult_().GetSummary() != nil &&
resp.GetResult_().GetConfigSummaryResult_().GetSummary().GetGroups() != nil &&
len(resp.GetResult_().GetConfigSummaryResult_().GetSummary().GetGroups()) > 0 &&
resp.GetResponseCode() == aurora.ResponseCode_OK,
nil
}

403
job.go
View file

@ -15,213 +15,358 @@
package realis
import (
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"strconv"
"github.com/paypal/gorealis/gen-go/apache/aurora"
)
// Structure to collect all information pertaining to an Aurora job.
type AuroraJob struct {
jobConfig *aurora.JobConfiguration
task *AuroraTask
// Job inteface is used to define a set of functions an Aurora Job object
// must implemement.
// TODO(rdelvalle): Consider getting rid of the Job interface
type Job interface {
// Set Job Key environment.
Environment(env string) Job
Role(role string) Job
Name(name string) Job
CronSchedule(cron string) Job
CronCollisionPolicy(policy aurora.CronCollisionPolicy) Job
CPU(cpus float64) Job
Disk(disk int64) Job
RAM(ram int64) Job
GPU(gpu int64) Job
ExecutorName(name string) Job
ExecutorData(data string) Job
AddPorts(num int) Job
AddLabel(key string, value string) Job
AddNamedPorts(names ...string) Job
AddLimitConstraint(name string, limit int32) Job
AddValueConstraint(name string, negated bool, values ...string) Job
// From Aurora Docs:
// dedicated attribute. Aurora treats this specially, and only allows matching jobs
// to run on these machines, and will only schedule matching jobs on these machines.
// When a job is created, the scheduler requires that the $role component matches
// the role field in the job configuration, and will reject the job creation otherwise.
// A wildcard (*) may be used for the role portion of the dedicated attribute, which
// will allow any owner to elect for a job to run on the host(s)
AddDedicatedConstraint(role, name string) Job
AddURIs(extract bool, cache bool, values ...string) Job
JobKey() *aurora.JobKey
JobConfig() *aurora.JobConfiguration
TaskConfig() *aurora.TaskConfig
IsService(isService bool) Job
InstanceCount(instCount int32) Job
GetInstanceCount() int32
MaxFailure(maxFail int32) Job
Container(container Container) Job
PartitionPolicy(policy *aurora.PartitionPolicy) Job
Tier(tier string) Job
SlaPolicy(policy *aurora.SlaPolicy) Job
}
// Create a AuroraJob object with everything initialized.
func NewJob() *AuroraJob {
type resourceType int
jobKey := &aurora.JobKey{}
const (
CPU resourceType = iota
RAM
DISK
GPU
)
// AuroraTask clientConfig
task := NewTask()
task.task.Job = jobKey
// AuroraJob is a structure to collect all information pertaining to an Aurora job.
type AuroraJob struct {
jobConfig *aurora.JobConfiguration
resources map[resourceType]*aurora.Resource
metadata map[string]*aurora.Metadata
portCount int
}
// AuroraJob clientConfig
jobConfig := &aurora.JobConfiguration{
Key: jobKey,
TaskConfig: task.TaskConfig(),
}
// NewJob is used to create a Job object with everything initialized.
func NewJob() Job {
jobConfig := aurora.NewJobConfiguration()
taskConfig := aurora.NewTaskConfig()
jobKey := aurora.NewJobKey()
// Job Config
jobConfig.Key = jobKey
jobConfig.TaskConfig = taskConfig
// Task Config
taskConfig.Job = jobKey
taskConfig.Container = aurora.NewContainer()
taskConfig.Container.Mesos = aurora.NewMesosContainer()
// Resources
numCpus := aurora.NewResource()
ramMb := aurora.NewResource()
diskMb := aurora.NewResource()
resources := map[resourceType]*aurora.Resource{CPU: numCpus, RAM: ramMb, DISK: diskMb}
taskConfig.Resources = []*aurora.Resource{numCpus, ramMb, diskMb}
numCpus.NumCpus = new(float64)
ramMb.RamMb = new(int64)
diskMb.DiskMb = new(int64)
return &AuroraJob{
jobConfig: jobConfig,
task: task,
resources: resources,
metadata: make(map[string]*aurora.Metadata),
portCount: 0,
}
}
// Set AuroraJob Key environment. Explicit changes to AuroraTask's job key are not needed
// because they share a pointer to the same JobKey.
func (j *AuroraJob) Environment(env string) *AuroraJob {
// Environment sets the Job Key environment.
func (j *AuroraJob) Environment(env string) Job {
j.jobConfig.Key.Environment = env
return j
}
// Set AuroraJob Key Role.
func (j *AuroraJob) Role(role string) *AuroraJob {
// Role sets the Job Key role.
func (j *AuroraJob) Role(role string) Job {
j.jobConfig.Key.Role = role
// Will be deprecated
identity := &aurora.Identity{User: role}
j.jobConfig.Owner = identity
j.jobConfig.TaskConfig.Owner = identity
return j
}
// Set AuroraJob Key Name.
func (j *AuroraJob) Name(name string) *AuroraJob {
// Name sets the Job Key Name.
func (j *AuroraJob) Name(name string) Job {
j.jobConfig.Key.Name = name
return j
}
// How many instances of the job to run
func (j *AuroraJob) InstanceCount(instCount int32) *AuroraJob {
// ExecutorName sets the name of the executor that will the task will be configured to.
func (j *AuroraJob) ExecutorName(name string) Job {
if j.jobConfig.TaskConfig.ExecutorConfig == nil {
j.jobConfig.TaskConfig.ExecutorConfig = aurora.NewExecutorConfig()
}
j.jobConfig.TaskConfig.ExecutorConfig.Name = name
return j
}
// ExecutorData sets the data blob that will be passed to the Mesos executor.
func (j *AuroraJob) ExecutorData(data string) Job {
if j.jobConfig.TaskConfig.ExecutorConfig == nil {
j.jobConfig.TaskConfig.ExecutorConfig = aurora.NewExecutorConfig()
}
j.jobConfig.TaskConfig.ExecutorConfig.Data = data
return j
}
// CPU sets the amount of CPU each task will use in an Aurora Job.
func (j *AuroraJob) CPU(cpus float64) Job {
*j.resources[CPU].NumCpus = cpus
return j
}
// RAM sets the amount of RAM each task will use in an Aurora Job.
func (j *AuroraJob) RAM(ram int64) Job {
*j.resources[RAM].RamMb = ram
return j
}
// Disk sets the amount of Disk each task will use in an Aurora Job.
func (j *AuroraJob) Disk(disk int64) Job {
*j.resources[DISK].DiskMb = disk
return j
}
// GPU sets the amount of GPU each task will use in an Aurora Job.
func (j *AuroraJob) GPU(gpu int64) Job {
// GPU resource must be set explicitly since the scheduler by default
// rejects jobs with GPU resources attached to it.
if _, ok := j.resources[GPU]; !ok {
j.resources[GPU] = &aurora.Resource{}
j.JobConfig().GetTaskConfig().Resources = append(
j.JobConfig().GetTaskConfig().Resources,
j.resources[GPU])
}
j.resources[GPU].NumGpus = &gpu
return j
}
// MaxFailure sets how many failures to tolerate before giving up per Job.
func (j *AuroraJob) MaxFailure(maxFail int32) Job {
j.jobConfig.TaskConfig.MaxTaskFailures = maxFail
return j
}
// InstanceCount sets how many instances of the task to run for this Job.
func (j *AuroraJob) InstanceCount(instCount int32) Job {
j.jobConfig.InstanceCount = instCount
return j
}
func (j *AuroraJob) CronSchedule(cron string) *AuroraJob {
// CronSchedule allows the user to configure a cron schedule for this job to run in.
func (j *AuroraJob) CronSchedule(cron string) Job {
j.jobConfig.CronSchedule = &cron
return j
}
func (j *AuroraJob) CronCollisionPolicy(policy aurora.CronCollisionPolicy) *AuroraJob {
// CronCollisionPolicy allows the user to decide what happens if two or more instances
// of the same Cron job need to run.
func (j *AuroraJob) CronCollisionPolicy(policy aurora.CronCollisionPolicy) Job {
j.jobConfig.CronCollisionPolicy = policy
return j
}
// How many instances of the job to run
// GetInstanceCount returns how many tasks this Job contains.
func (j *AuroraJob) GetInstanceCount() int32 {
return j.jobConfig.InstanceCount
}
// Get the current job configurations key to use for some realis calls.
func (j *AuroraJob) JobKey() aurora.JobKey {
return *j.jobConfig.Key
// IsService returns true if the job is a long term running job or false if it is an ad-hoc job.
func (j *AuroraJob) IsService(isService bool) Job {
j.jobConfig.TaskConfig.IsService = isService
return j
}
// Get the current job configurations key to use for some realis calls.
// JobKey returns the job's configuration key.
func (j *AuroraJob) JobKey() *aurora.JobKey {
return j.jobConfig.Key
}
// JobConfig returns the job's configuration.
func (j *AuroraJob) JobConfig() *aurora.JobConfiguration {
return j.jobConfig
}
// Get the current job configurations key to use for some realis calls.
func (j *AuroraJob) AuroraTask() *AuroraTask {
return j.task
}
/*
AuroraTask specific API, see task.go for further documentation.
These functions are provided for the convenience of chaining API calls.
*/
func (j *AuroraJob) ExecutorName(name string) *AuroraJob {
j.task.ExecutorName(name)
return j
}
func (j *AuroraJob) ExecutorData(data string) *AuroraJob {
j.task.ExecutorData(data)
return j
}
func (j *AuroraJob) CPU(cpus float64) *AuroraJob {
j.task.CPU(cpus)
return j
}
func (j *AuroraJob) RAM(ram int64) *AuroraJob {
j.task.RAM(ram)
return j
}
func (j *AuroraJob) Disk(disk int64) *AuroraJob {
j.task.Disk(disk)
return j
}
func (j *AuroraJob) GPU(gpu int64) *AuroraJob {
j.task.GPU(gpu)
return j
}
func (j *AuroraJob) Tier(tier string) *AuroraJob {
j.task.Tier(tier)
return j
}
func (j *AuroraJob) MaxFailure(maxFail int32) *AuroraJob {
j.task.MaxFailure(maxFail)
return j
}
func (j *AuroraJob) IsService(isService bool) *AuroraJob {
j.task.IsService(isService)
return j
}
func (j *AuroraJob) Priority(priority int32) *AuroraJob {
j.task.Priority(priority)
return j
}
func (j *AuroraJob) Production(production bool) *AuroraJob {
j.task.Production(production)
return j
}
// TaskConfig returns the job's task(shard) configuration.
func (j *AuroraJob) TaskConfig() *aurora.TaskConfig {
return j.task.TaskConfig()
return j.jobConfig.TaskConfig
}
func (j *AuroraJob) AddURIs(extract bool, cache bool, values ...string) *AuroraJob {
j.task.AddURIs(extract, cache, values...)
// AddURIs adds a list of URIs with the same extract and cache configuration. Scheduler must have
// --enable_mesos_fetcher flag enabled. Currently there is no duplicate detection.
func (j *AuroraJob) AddURIs(extract bool, cache bool, values ...string) Job {
for _, value := range values {
j.jobConfig.TaskConfig.MesosFetcherUris = append(j.jobConfig.TaskConfig.MesosFetcherUris,
&aurora.MesosFetcherURI{Value: value, Extract: &extract, Cache: &cache})
}
return j
}
func (j *AuroraJob) AddLabel(key string, value string) *AuroraJob {
j.task.AddLabel(key, value)
// AddLabel adds a Mesos label to the job. Note that Aurora will add the
// prefix "org.apache.aurora.metadata." to the beginning of each key.
func (j *AuroraJob) AddLabel(key string, value string) Job {
if _, ok := j.metadata[key]; ok {
j.metadata[key].Value = value
} else {
j.metadata[key] = &aurora.Metadata{Key: key, Value: value}
j.jobConfig.TaskConfig.Metadata = append(j.jobConfig.TaskConfig.Metadata, j.metadata[key])
}
return j
}
func (j *AuroraJob) AddNamedPorts(names ...string) *AuroraJob {
j.task.AddNamedPorts(names...)
// AddNamedPorts adds a named port to the job configuration These are random ports as it's
// not currently possible to request specific ports using Aurora.
func (j *AuroraJob) AddNamedPorts(names ...string) Job {
j.portCount += len(names)
for _, name := range names {
j.jobConfig.TaskConfig.Resources = append(
j.jobConfig.TaskConfig.Resources,
&aurora.Resource{NamedPort: &name})
}
return j
}
func (j *AuroraJob) AddPorts(num int) *AuroraJob {
j.task.AddPorts(num)
return j
}
func (j *AuroraJob) AddValueConstraint(name string, negated bool, values ...string) *AuroraJob {
j.task.AddValueConstraint(name, negated, values...)
// AddPorts adds a request for a number of ports to the job configuration. The names chosen for these ports
// will be org.apache.aurora.port.X, where X is the current port count for the job configuration
// starting at 0. These are random ports as it's not currently possible to request
// specific ports using Aurora.
func (j *AuroraJob) AddPorts(num int) Job {
start := j.portCount
j.portCount += num
for i := start; i < j.portCount; i++ {
portName := "org.apache.aurora.port." + strconv.Itoa(i)
j.jobConfig.TaskConfig.Resources = append(
j.jobConfig.TaskConfig.Resources,
&aurora.Resource{NamedPort: &portName})
}
return j
}
func (j *AuroraJob) AddLimitConstraint(name string, limit int32) *AuroraJob {
j.task.AddLimitConstraint(name, limit)
// AddValueConstraint allows the user to add a value constrain to the job to limiti which agents the job's
// tasks can be run on.
// From Aurora Docs:
// Add a Value constraint
// name - Mesos slave attribute that the constraint is matched against.
// If negated = true , treat this as a 'not' - to avoid specific values.
// Values - list of values we look for in attribute name
func (j *AuroraJob) AddValueConstraint(name string, negated bool, values ...string) Job {
j.jobConfig.TaskConfig.Constraints = append(j.jobConfig.TaskConfig.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: &aurora.ValueConstraint{
Negated: negated,
Values: values,
},
Limit: nil,
},
})
return j
}
func (j *AuroraJob) AddDedicatedConstraint(role, name string) *AuroraJob {
j.task.AddDedicatedConstraint(role, name)
// AddLimitConstraint allows the user to limit how many tasks form the same Job are run on a single host.
// From Aurora Docs:
// A constraint that specifies the maximum number of active tasks on a host with
// a matching attribute that may be scheduled simultaneously.
func (j *AuroraJob) AddLimitConstraint(name string, limit int32) Job {
j.jobConfig.TaskConfig.Constraints = append(j.jobConfig.TaskConfig.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: nil,
Limit: &aurora.LimitConstraint{Limit: limit},
},
})
return j
}
func (j *AuroraJob) Container(container Container) *AuroraJob {
j.task.Container(container)
// AddDedicatedConstraint allows the user to add a dedicated constraint to a Job configuration.
func (j *AuroraJob) AddDedicatedConstraint(role, name string) Job {
j.AddValueConstraint("dedicated", false, role+"/"+name)
return j
}
func (j *AuroraJob) ThermosExecutor(thermos ThermosExecutor) *AuroraJob {
j.task.ThermosExecutor(thermos)
// Container sets a container to run for the job configuration to run.
func (j *AuroraJob) Container(container Container) Job {
j.jobConfig.TaskConfig.Container = container.Build()
return j
}
func (j *AuroraJob) BuildThermosPayload() error {
return j.task.BuildThermosPayload()
}
func (j *AuroraJob) PartitionPolicy(reschedule bool, delay int64) *AuroraJob {
j.task.PartitionPolicy(aurora.PartitionPolicy{
Reschedule: reschedule,
DelaySecs: &delay,
})
// PartitionPolicy sets a partition policy for the job configuration to implement.
func (j *AuroraJob) PartitionPolicy(policy *aurora.PartitionPolicy) Job {
j.jobConfig.TaskConfig.PartitionPolicy = policy
return j
}
// Tier sets the Tier for the Job.
func (j *AuroraJob) Tier(tier string) Job {
j.jobConfig.TaskConfig.Tier = &tier
return j
}
// SlaPolicy sets an SlaPolicy for the Job.
func (j *AuroraJob) SlaPolicy(policy *aurora.SlaPolicy) Job {
j.jobConfig.TaskConfig.SlaPolicy = policy
return j
}

View file

@ -1,296 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"time"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Structure to collect all information required to create job update
type JobUpdate struct {
task *AuroraTask
request *aurora.JobUpdateRequest
}
// Create a default JobUpdate object with an empty task and no fields filled in.
func NewJobUpdate() *JobUpdate {
newTask := NewTask()
return &JobUpdate{
task: newTask,
request: &aurora.JobUpdateRequest{TaskConfig: newTask.TaskConfig(), Settings: newUpdateSettings()},
}
}
// Creates an update with default values using an AuroraTask as the underlying task configuration.
// This function has a high level understanding of Aurora Tasks and thus will support copying a task that is configured
// to use Thermos.
func JobUpdateFromAuroraTask(task *AuroraTask) *JobUpdate {
newTask := task.Clone()
return &JobUpdate{
task: newTask,
request: &aurora.JobUpdateRequest{TaskConfig: newTask.TaskConfig(), Settings: newUpdateSettings()},
}
}
// JobUpdateFromConfig creates an update with default values using an aurora.TaskConfig
// primitive as the underlying task configuration.
// This function should not be used unless the implications of using a primitive value are understood.
// For example, the primitive has no concept of Thermos.
func JobUpdateFromConfig(task *aurora.TaskConfig) *JobUpdate {
// Perform a deep copy to avoid unexpected behavior
newTask := TaskFromThrift(task)
return &JobUpdate{
task: newTask,
request: &aurora.JobUpdateRequest{TaskConfig: newTask.TaskConfig(), Settings: newUpdateSettings()},
}
}
// Set instance count the job will have after the update.
func (j *JobUpdate) InstanceCount(inst int32) *JobUpdate {
j.request.InstanceCount = inst
return j
}
// Max number of instances being updated at any given moment.
func (j *JobUpdate) BatchSize(size int32) *JobUpdate {
j.request.Settings.UpdateGroupSize = size
return j
}
// Minimum number of seconds a shard must remain in RUNNING state before considered a success.
func (j *JobUpdate) WatchTime(timeout time.Duration) *JobUpdate {
j.request.Settings.MinWaitInInstanceRunningMs = int32(timeout.Milliseconds())
return j
}
// Wait for all instances in a group to be done before moving on.
func (j *JobUpdate) WaitForBatchCompletion(batchWait bool) *JobUpdate {
j.request.Settings.WaitForBatchCompletion = batchWait
return j
}
// Max number of instance failures to tolerate before marking instance as FAILED.
func (j *JobUpdate) MaxPerInstanceFailures(inst int32) *JobUpdate {
j.request.Settings.MaxPerInstanceFailures = inst
return j
}
// Max number of FAILED instances to tolerate before terminating the update.
func (j *JobUpdate) MaxFailedInstances(inst int32) *JobUpdate {
j.request.Settings.MaxFailedInstances = inst
return j
}
// When False, prevents auto rollback of a failed update.
func (j *JobUpdate) RollbackOnFail(rollback bool) *JobUpdate {
j.request.Settings.RollbackOnFailure = rollback
return j
}
// Sets the interval at which pulses should be received by the job update before timing out.
func (j *JobUpdate) PulseIntervalTimeout(timeout time.Duration) *JobUpdate {
j.request.Settings.BlockIfNoPulsesAfterMs = thrift.Int32Ptr(int32(timeout.Seconds() * 1000))
return j
}
func (j *JobUpdate) BatchUpdateStrategy(autoPause bool, batchSize int32) *JobUpdate {
j.request.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{
BatchStrategy: &aurora.BatchJobUpdateStrategy{GroupSize: batchSize, AutopauseAfterBatch: autoPause},
}
return j
}
func (j *JobUpdate) QueueUpdateStrategy(groupSize int32) *JobUpdate {
j.request.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{
QueueStrategy: &aurora.QueueJobUpdateStrategy{GroupSize: groupSize},
}
return j
}
func (j *JobUpdate) VariableBatchStrategy(autoPause bool, batchSizes ...int32) *JobUpdate {
j.request.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{
VarBatchStrategy: &aurora.VariableBatchJobUpdateStrategy{GroupSizes: batchSizes, AutopauseAfterBatch: autoPause},
}
return j
}
// SlaAware makes the scheduler enforce the SLA Aware policy if the job meets the SLA awareness criteria.
// By default, the scheduler will only apply SLA Awareness to jobs in the production tier with 20 or more instances.
func (j *JobUpdate) SlaAware(slaAware bool) *JobUpdate {
j.request.Settings.SlaAware = &slaAware
return j
}
// AddInstanceRange allows updates to only touch a certain specific range of instances
func (j *JobUpdate) AddInstanceRange(first, last int32) *JobUpdate {
j.request.Settings.UpdateOnlyTheseInstances = append(j.request.Settings.UpdateOnlyTheseInstances,
&aurora.Range{First: first, Last: last})
return j
}
func newUpdateSettings() *aurora.JobUpdateSettings {
us := aurora.JobUpdateSettings{}
// Mirrors defaults set by Pystachio
us.UpdateOnlyTheseInstances = []*aurora.Range{}
us.UpdateGroupSize = 1
us.WaitForBatchCompletion = false
us.MinWaitInInstanceRunningMs = 45000
us.MaxPerInstanceFailures = 0
us.MaxFailedInstances = 0
us.RollbackOnFailure = true
return &us
}
/*
These methods are provided for user convenience in order to chain
calls for configuration.
API below here are wrappers around modifying an AuroraTask instance.
See task.go for further documentation.
*/
func (j *JobUpdate) Environment(env string) *JobUpdate {
j.task.Environment(env)
return j
}
func (j *JobUpdate) Role(role string) *JobUpdate {
j.task.Role(role)
return j
}
func (j *JobUpdate) Name(name string) *JobUpdate {
j.task.Name(name)
return j
}
func (j *JobUpdate) ExecutorName(name string) *JobUpdate {
j.task.ExecutorName(name)
return j
}
func (j *JobUpdate) ExecutorData(data string) *JobUpdate {
j.task.ExecutorData(data)
return j
}
func (j *JobUpdate) CPU(cpus float64) *JobUpdate {
j.task.CPU(cpus)
return j
}
func (j *JobUpdate) RAM(ram int64) *JobUpdate {
j.task.RAM(ram)
return j
}
func (j *JobUpdate) Disk(disk int64) *JobUpdate {
j.task.Disk(disk)
return j
}
func (j *JobUpdate) Tier(tier string) *JobUpdate {
j.task.Tier(tier)
return j
}
func (j *JobUpdate) TaskMaxFailure(maxFail int32) *JobUpdate {
j.task.MaxFailure(maxFail)
return j
}
func (j *JobUpdate) IsService(isService bool) *JobUpdate {
j.task.IsService(isService)
return j
}
func (j *JobUpdate) Priority(priority int32) *JobUpdate {
j.task.Priority(priority)
return j
}
func (j *JobUpdate) Production(production bool) *JobUpdate {
j.task.Production(production)
return j
}
func (j *JobUpdate) TaskConfig() *aurora.TaskConfig {
return j.task.TaskConfig()
}
func (j *JobUpdate) AddURIs(extract bool, cache bool, values ...string) *JobUpdate {
j.task.AddURIs(extract, cache, values...)
return j
}
func (j *JobUpdate) AddLabel(key string, value string) *JobUpdate {
j.task.AddLabel(key, value)
return j
}
func (j *JobUpdate) AddNamedPorts(names ...string) *JobUpdate {
j.task.AddNamedPorts(names...)
return j
}
func (j *JobUpdate) AddPorts(num int) *JobUpdate {
j.task.AddPorts(num)
return j
}
func (j *JobUpdate) AddValueConstraint(name string, negated bool, values ...string) *JobUpdate {
j.task.AddValueConstraint(name, negated, values...)
return j
}
func (j *JobUpdate) AddLimitConstraint(name string, limit int32) *JobUpdate {
j.task.AddLimitConstraint(name, limit)
return j
}
func (j *JobUpdate) AddDedicatedConstraint(role, name string) *JobUpdate {
j.task.AddDedicatedConstraint(role, name)
return j
}
func (j *JobUpdate) Container(container Container) *JobUpdate {
j.task.Container(container)
return j
}
func (j *JobUpdate) JobKey() aurora.JobKey {
return j.task.JobKey()
}
func (j *JobUpdate) ThermosExecutor(thermos ThermosExecutor) *JobUpdate {
j.task.ThermosExecutor(thermos)
return j
}
func (j *JobUpdate) BuildThermosPayload() error {
return j.task.BuildThermosPayload()
}
func (j *JobUpdate) PartitionPolicy(reschedule bool, delay int64) *JobUpdate {
j.task.PartitionPolicy(aurora.PartitionPolicy{
Reschedule: reschedule,
DelaySecs: &delay,
})
return j
}

View file

@ -14,65 +14,73 @@
package realis
type Logger interface {
type logger interface {
Println(v ...interface{})
Printf(format string, v ...interface{})
Print(v ...interface{})
}
// NoopLogger is a logger that can be attached to the client which will not print anything.
type NoopLogger struct{}
// Printf is a NOOP function here.
func (NoopLogger) Printf(format string, a ...interface{}) {}
// Print is a NOOP function here.
func (NoopLogger) Print(a ...interface{}) {}
// Println is a NOOP function here.
func (NoopLogger) Println(a ...interface{}) {}
// LevelLogger is a logger that can be configured to output different levels of information: Debug and Trace.
// Trace should only be enabled when very in depth information about the sequence of events a function took is needed.
type LevelLogger struct {
Logger
logger
debug bool
trace bool
}
// EnableDebug enables debug level logging for the LevelLogger
func (l *LevelLogger) EnableDebug(enable bool) {
l.debug = enable
}
// EnableTrace enables trace level logging for the LevelLogger
func (l *LevelLogger) EnableTrace(enable bool) {
l.trace = enable
}
func (l LevelLogger) DebugPrintf(format string, a ...interface{}) {
func (l LevelLogger) debugPrintf(format string, a ...interface{}) {
if l.debug {
l.Printf("[DEBUG] "+format, a...)
}
}
func (l LevelLogger) DebugPrint(a ...interface{}) {
func (l LevelLogger) debugPrint(a ...interface{}) {
if l.debug {
l.Print(append([]interface{}{"[DEBUG] "}, a...)...)
}
}
func (l LevelLogger) DebugPrintln(a ...interface{}) {
func (l LevelLogger) debugPrintln(a ...interface{}) {
if l.debug {
l.Println(append([]interface{}{"[DEBUG] "}, a...)...)
}
}
func (l LevelLogger) TracePrintf(format string, a ...interface{}) {
func (l LevelLogger) tracePrintf(format string, a ...interface{}) {
if l.trace {
l.Printf("[TRACE] "+format, a...)
}
}
func (l LevelLogger) TracePrint(a ...interface{}) {
func (l LevelLogger) tracePrint(a ...interface{}) {
if l.trace {
l.Print(append([]interface{}{"[TRACE] "}, a...)...)
}
}
func (l LevelLogger) TracePrintln(a ...interface{}) {
func (l LevelLogger) tracePrintln(a ...interface{}) {
if l.trace {
l.Println(append([]interface{}{"[TRACE] "}, a...)...)
}

View file

@ -12,48 +12,46 @@
* limitations under the License.
*/
// Collection of monitors to create synchronicity
package realis
import (
"time"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/pkg/errors"
)
// MonitorJobUpdate polls the scheduler every certain amount of time to see if the update has succeeded.
// If the update entered a terminal update state but it is not ROLLED_FORWARD, this function will return an error.
func (c *Client) MonitorJobUpdate(updateKey aurora.JobUpdateKey, interval, timeout time.Duration) (bool, error) {
if interval < 1*time.Second {
interval = interval * time.Second
}
// Monitor is a wrapper for the Realis client which allows us to have functions
// with the same name for Monitoring purposes.
// TODO(rdelvalle): Deprecate monitors and instead add prefix Monitor to
// all functions in this file like it is done in V2.
type Monitor struct {
Client Realis
}
if timeout < 1*time.Second {
timeout = timeout * time.Second
// JobUpdate polls the scheduler every certain amount of time to see if the update has entered a terminal state.
func (m *Monitor) JobUpdate(
updateKey aurora.JobUpdateKey,
interval int,
timeout int) (bool, error) {
updateQ := aurora.JobUpdateQuery{
Key: &updateKey,
Limit: 1,
UpdateStatuses: TerminalUpdateStates(),
}
updateSummaries, err := c.MonitorJobUpdateQuery(
aurora.JobUpdateQuery{
Key: &updateKey,
Limit: 1,
UpdateStatuses: []aurora.JobUpdateStatus{
aurora.JobUpdateStatus_ROLLED_FORWARD,
aurora.JobUpdateStatus_ROLLED_BACK,
aurora.JobUpdateStatus_ABORTED,
aurora.JobUpdateStatus_ERROR,
aurora.JobUpdateStatus_FAILED,
},
},
interval,
timeout)
updateSummaries, err := m.JobUpdateQuery(
updateQ,
time.Duration(interval)*time.Second,
time.Duration(timeout)*time.Second)
status := updateSummaries[0].State.Status
if err != nil {
return false, err
}
status := updateSummaries[0].State.Status
c.RealisConfig().logger.Printf("job update status: %v\n", status)
m.Client.RealisConfig().logger.Printf("job update status: %v\n", status)
// Rolled forward is the only state in which an update has been successfully updated
// if we encounter an inactive state and it is not at rolled forward, update failed
@ -70,41 +68,22 @@ func (c *Client) MonitorJobUpdate(updateKey aurora.JobUpdateKey, interval, timeo
}
}
// MonitorJobUpdateStatus polls the scheduler for information about an update until the update enters one of the
// desired states or until the function times out.
func (c *Client) MonitorJobUpdateStatus(updateKey aurora.JobUpdateKey,
// JobUpdateStatus polls the scheduler every certain amount of time to see if the update has entered a specified state.
func (m *Monitor) JobUpdateStatus(updateKey aurora.JobUpdateKey,
desiredStatuses []aurora.JobUpdateStatus,
interval, timeout time.Duration) (aurora.JobUpdateStatus, error) {
if len(desiredStatuses) == 0 {
return aurora.JobUpdateStatus(-1), errors.New("no desired statuses provided")
}
// Make deep local copy to avoid side effects from job key being manipulated externally.
updateKeyLocal := &aurora.JobUpdateKey{
Job: &aurora.JobKey{
Role: updateKey.Job.GetRole(),
Environment: updateKey.Job.GetEnvironment(),
Name: updateKey.Job.GetName(),
},
ID: updateKey.GetID(),
}
updateQ := aurora.JobUpdateQuery{
Key: updateKeyLocal,
Key: &updateKey,
Limit: 1,
UpdateStatuses: desiredStatuses,
}
summary, err := m.JobUpdateQuery(updateQ, interval, timeout)
summary, err := c.MonitorJobUpdateQuery(updateQ, interval, timeout)
if len(summary) > 0 {
return summary[0].State.Status, err
}
return aurora.JobUpdateStatus(-1), err
return summary[0].State.Status, err
}
func (c *Client) MonitorJobUpdateQuery(
// JobUpdateQuery polls the scheduler every certain amount of time to see if the query call returns any results.
func (m *Monitor) JobUpdateQuery(
updateQuery aurora.JobUpdateQuery,
interval time.Duration,
timeout time.Duration) ([]*aurora.JobUpdateSummary, error) {
@ -113,16 +92,20 @@ func (c *Client) MonitorJobUpdateQuery(
defer ticker.Stop()
timer := time.NewTimer(timeout)
defer timer.Stop()
var cliErr error
var respDetail *aurora.Response
for {
select {
case <-ticker.C:
updateSummaryResults, cliErr := c.GetJobUpdateSummaries(&updateQuery)
respDetail, cliErr = m.Client.GetJobUpdateSummaries(&updateQuery)
if cliErr != nil {
return nil, cliErr
}
if len(updateSummaryResults.GetUpdateSummaries()) >= 1 {
return updateSummaryResults.GetUpdateSummaries(), nil
updateSummaries := respDetail.Result_.GetJobUpdateSummariesResult_.UpdateSummaries
if len(updateSummaries) >= 1 {
return updateSummaries, nil
}
case <-timer.C:
@ -131,37 +114,104 @@ func (c *Client) MonitorJobUpdateQuery(
}
}
// Monitor a AuroraJob until all instances enter one of the LiveStates
func (c *Client) MonitorInstances(key aurora.JobKey, instances int32, interval, timeout time.Duration) (bool, error) {
return c.MonitorScheduleStatus(key, instances, aurora.LIVE_STATES, interval, timeout)
// AutoPaused monitor is a special monitor for auto pause enabled batch updates. This monitor ensures that the update
// being monitored is capable of auto pausing and has auto pausing enabled. After verifying this information,
// the monitor watches for the job to enter the ROLL_FORWARD_PAUSED state and calculates the current batch
// the update is in using information from the update configuration.
func (m *Monitor) AutoPausedUpdateMonitor(key aurora.JobUpdateKey, interval, timeout time.Duration) (int, error) {
key.Job = &aurora.JobKey{
Role: key.Job.Role,
Environment: key.Job.Environment,
Name: key.Job.Name,
}
query := aurora.JobUpdateQuery{
UpdateStatuses: aurora.ACTIVE_JOB_UPDATE_STATES,
Limit: 1,
Key: &key,
}
response, err := m.Client.JobUpdateDetails(query)
if err != nil {
return -1, errors.Wrap(err, "unable to get information about update")
}
// TODO (rdelvalle): check for possible nil values when going down the list of structs
updateDetails := response.Result_.GetJobUpdateDetailsResult_.DetailsList
if len(updateDetails) == 0 {
return -1, errors.Errorf("details for update could not be found")
}
updateStrategy := updateDetails[0].Update.Instructions.Settings.UpdateStrategy
var batchSizes []int32
switch {
case updateStrategy.IsSetVarBatchStrategy():
batchSizes = updateStrategy.VarBatchStrategy.GroupSizes
if !updateStrategy.VarBatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
case updateStrategy.IsSetBatchStrategy():
batchSizes = []int32{updateStrategy.BatchStrategy.GroupSize}
if !updateStrategy.BatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
default:
return -1, errors.Errorf("update is not using a batch update strategy")
}
query.UpdateStatuses = append(TerminalUpdateStates(), aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED)
summary, err := m.JobUpdateQuery(query, interval, timeout)
if err != nil {
return -1, err
}
if summary[0].State.Status != aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED {
return -1, errors.Errorf("update is in a terminal state %v", summary[0].State.Status)
}
updatingInstances := make(map[int32]struct{})
for _, e := range updateDetails[0].InstanceEvents {
// We only care about INSTANCE_UPDATING actions because we only care that they've been attempted
if e != nil && e.GetAction() == aurora.JobUpdateAction_INSTANCE_UPDATING {
updatingInstances[e.GetInstanceId()] = struct{}{}
}
}
return calculateCurrentBatch(int32(len(updatingInstances)), batchSizes), nil
}
// Monitor a AuroraJob until all instances enter a desired status.
// Monitor a Job until all instances enter one of the LIVE_STATES
func (m *Monitor) Instances(key *aurora.JobKey, instances int32, interval, timeout int) (bool, error) {
return m.ScheduleStatus(key, instances, LiveStates, interval, timeout)
}
// ScheduleStatus will monitor a Job until all instances enter a desired status.
// Defaults sets of desired statuses provided by the thrift API include:
// ActiveStates, SlaveAssignedStates, LiveStates, and TerminalStates
func (c *Client) MonitorScheduleStatus(key aurora.JobKey,
// ACTIVE_STATES, SLAVE_ASSIGNED_STATES, LIVE_STATES, and TERMINAL_STATES
func (m *Monitor) ScheduleStatus(
key *aurora.JobKey,
instanceCount int32,
desiredStatuses []aurora.ScheduleStatus,
interval, timeout time.Duration) (bool, error) {
if interval < 1*time.Second {
interval = interval * time.Second
}
desiredStatuses map[aurora.ScheduleStatus]bool,
interval int,
timeout int) (bool, error) {
if timeout < 1*time.Second {
timeout = timeout * time.Second
}
ticker := time.NewTicker(interval)
ticker := time.NewTicker(time.Second * time.Duration(interval))
defer ticker.Stop()
timer := time.NewTimer(timeout)
timer := time.NewTimer(time.Second * time.Duration(timeout))
defer timer.Stop()
wantedStatuses := make([]aurora.ScheduleStatus, 0)
for status := range desiredStatuses {
wantedStatuses = append(wantedStatuses, status)
}
for {
select {
case <-ticker.C:
// Query Aurora for the state of the job key ever interval
instCount, cliErr := c.GetInstanceIds(key, desiredStatuses)
instCount, cliErr := m.Client.GetInstanceIds(key, wantedStatuses)
if cliErr != nil {
return false, errors.Wrap(cliErr, "Unable to communicate with Aurora")
}
@ -171,23 +221,18 @@ func (c *Client) MonitorScheduleStatus(key aurora.JobKey,
case <-timer.C:
// If the timer runs out, return a timeout error to user
return false, newTimedoutError(errors.New("schedule status monitor timedout"))
return false, newTimedoutError(errors.New("schedule status monitor timed out"))
}
}
}
// Monitor host status until all hosts match the status provided. Returns a map where the value is true if the host
// HostMaintenance will monitor host status until all hosts match the status provided.
// Returns a map where the value is true if the host
// is in one of the desired mode(s) or false if it is not as of the time when the monitor exited.
func (c *Client) MonitorHostMaintenance(hosts []string,
func (m *Monitor) HostMaintenance(
hosts []string,
modes []aurora.MaintenanceMode,
interval, timeout time.Duration) (map[string]bool, error) {
if interval < 1*time.Second {
interval = interval * time.Second
}
if timeout < 1*time.Second {
timeout = timeout * time.Second
}
interval, timeout int) (map[string]bool, error) {
// Transform modes to monitor for into a set for easy lookup
desiredMode := make(map[aurora.MaintenanceMode]struct{})
@ -196,7 +241,8 @@ func (c *Client) MonitorHostMaintenance(hosts []string,
}
// Turn slice into a host set to eliminate duplicates.
// We also can't use a simple count because multiple modes means we can have multiple matches for a single host.
// We also can't use a simple count because multiple modes means
// we can have multiple matches for a single host.
// I.e. host A transitions from ACTIVE to DRAINING to DRAINED while monitored
remainingHosts := make(map[string]struct{})
for _, host := range hosts {
@ -205,16 +251,16 @@ func (c *Client) MonitorHostMaintenance(hosts []string,
hostResult := make(map[string]bool)
ticker := time.NewTicker(interval)
ticker := time.NewTicker(time.Second * time.Duration(interval))
defer ticker.Stop()
timer := time.NewTimer(timeout)
timer := time.NewTimer(time.Second * time.Duration(timeout))
defer timer.Stop()
for {
select {
case <-ticker.C:
// Client call has multiple retries internally
result, err := c.MaintenanceStatus(hosts...)
_, result, err := m.Client.MaintenanceStatus(hosts...)
if err != nil {
// Error is either a payload error or a severe connection error
for host := range remainingHosts {
@ -240,73 +286,7 @@ func (c *Client) MonitorHostMaintenance(hosts []string,
hostResult[host] = false
}
return hostResult, newTimedoutError(errors.New("host maintenance monitor timedout"))
return hostResult, newTimedoutError(errors.New("host maintenance monitor timed out"))
}
}
}
// MonitorAutoPausedUpdate is a special monitor for auto pause enabled batch updates. This monitor ensures that the update
// being monitored is capable of auto pausing and has auto pausing enabled. After verifying this information,
// the monitor watches for the job to enter the ROLL_FORWARD_PAUSED state and calculates the current batch
// the update is in using information from the update configuration.
func (c *Client) MonitorAutoPausedUpdate(key aurora.JobUpdateKey, interval, timeout time.Duration) (int, error) {
key.Job = &aurora.JobKey{
Role: key.Job.Role,
Environment: key.Job.Environment,
Name: key.Job.Name,
}
query := aurora.JobUpdateQuery{
UpdateStatuses: aurora.ACTIVE_JOB_UPDATE_STATES,
Limit: 1,
Key: &key,
}
updateDetails, err := c.JobUpdateDetails(query)
if err != nil {
return -1, errors.Wrap(err, "unable to get information about update")
}
if len(updateDetails) == 0 {
return -1, errors.Errorf("details for update could not be found")
}
updateStrategy := updateDetails[0].Update.Instructions.Settings.UpdateStrategy
var batchSizes []int32
switch {
case updateStrategy.IsSetVarBatchStrategy():
batchSizes = updateStrategy.VarBatchStrategy.GroupSizes
if !updateStrategy.VarBatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
case updateStrategy.IsSetBatchStrategy():
batchSizes = []int32{updateStrategy.BatchStrategy.GroupSize}
if !updateStrategy.BatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
default:
return -1, errors.Errorf("update is not using a batch update strategy")
}
query.UpdateStatuses = append(TerminalUpdateStates(), aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED)
summary, err := c.MonitorJobUpdateQuery(query, interval, timeout)
if err != nil {
return -1, err
}
// Summary 0 is assumed to exist because MonitorJobUpdateQuery will return an error if there is no summaries
if !(summary[0].State.Status == aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED ||
summary[0].State.Status == aurora.JobUpdateStatus_ROLLED_FORWARD) {
return -1, errors.Errorf("update is in a terminal state %v", summary[0].State.Status)
}
updatingInstances := make(map[int32]struct{})
for _, e := range updateDetails[0].InstanceEvents {
// We only care about INSTANCE_UPDATING actions because we only care that they've been attempted
if e != nil && e.GetAction() == aurora.JobUpdateAction_INSTANCE_UPDATING {
updatingInstances[e.GetInstanceId()] = struct{}{}
}
}
return calculateCurrentBatch(int32(len(updatingInstances)), batchSizes), nil
}

434
offer.go
View file

@ -1,434 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Offers on [aurora-scheduler]/offers endpoint
type Offer struct {
ID struct {
Value string `json:"value"`
} `json:"id"`
FrameworkID struct {
Value string `json:"value"`
} `json:"framework_id"`
AgentID struct {
Value string `json:"value"`
} `json:"agent_id"`
Hostname string `json:"hostname"`
URL struct {
Scheme string `json:"scheme"`
Address struct {
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port int `json:"port"`
} `json:"address"`
Path string `json:"path"`
Query []interface{} `json:"query"`
} `json:"url"`
Resources []struct {
Name string `json:"name"`
Type string `json:"type"`
Ranges struct {
Range []struct {
Begin int `json:"begin"`
End int `json:"end"`
} `json:"range"`
} `json:"ranges,omitempty"`
Role string `json:"role"`
Reservations []interface{} `json:"reservations"`
Scalar struct {
Value float64 `json:"value"`
} `json:"scalar,omitempty"`
} `json:"resources"`
Attributes []struct {
Name string `json:"name"`
Type string `json:"type"`
Text struct {
Value string `json:"value"`
} `json:"text"`
} `json:"attributes"`
ExecutorIds []struct {
Value string `json:"value"`
} `json:"executor_ids"`
}
// hosts on [aurora-scheduler]/maintenance endpoint
type MaintenanceList struct {
Drained []string `json:"DRAINED"`
Scheduled []string `json:"SCHEDULED"`
Draining map[string][]string `json:"DRAINING"`
}
type OfferCount map[float64]int64
type OfferGroupReport map[string]OfferCount
type OfferReport map[string]OfferGroupReport
// MaintenanceHosts list all the hosts under maintenance
func (c *Client) MaintenanceHosts() ([]string, error) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.config.insecureSkipVerify},
}
request := &http.Client{Transport: tr}
resp, err := request.Get(fmt.Sprintf("%s/maintenance", c.GetSchedulerURL()))
if err != nil {
return nil, err
}
defer resp.Body.Close()
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(resp.Body); err != nil {
return nil, err
}
var list MaintenanceList
if err := json.Unmarshal(buf.Bytes(), &list); err != nil {
return nil, err
}
hosts := append(list.Drained, list.Scheduled...)
for drainingHost := range list.Draining {
hosts = append(hosts, drainingHost)
}
return hosts, nil
}
// Offers pulls data from /offers endpoint
func (c *Client) Offers() ([]Offer, error) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.config.insecureSkipVerify},
}
request := &http.Client{Transport: tr}
resp, err := request.Get(fmt.Sprintf("%s/offers", c.GetSchedulerURL()))
if err != nil {
return []Offer{}, err
}
defer resp.Body.Close()
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(resp.Body); err != nil {
return nil, err
}
var offers []Offer
if err := json.Unmarshal(buf.Bytes(), &offers); err != nil {
return []Offer{}, err
}
return offers, nil
}
// AvailOfferReport returns a detailed summary of offers available for use.
// For example, 2 nodes offer 32 cpus and 10 nodes offer 1 cpus.
func (c *Client) AvailOfferReport() (OfferReport, error) {
maintHosts, err := c.MaintenanceHosts()
if err != nil {
return nil, err
}
maintHostSet := map[string]bool{}
for _, h := range maintHosts {
maintHostSet[h] = true
}
// Get a list of offers
offers, err := c.Offers()
if err != nil {
return nil, err
}
report := OfferReport{}
for _, o := range offers {
if maintHostSet[o.Hostname] {
continue
}
group := "non-dedicated"
for _, a := range o.Attributes {
if a.Name == "dedicated" {
group = a.Text.Value
break
}
}
if _, ok := report[group]; !ok {
report[group] = map[string]OfferCount{}
}
for _, r := range o.Resources {
if _, ok := report[group][r.Name]; !ok {
report[group][r.Name] = OfferCount{}
}
val := 0.0
switch r.Type {
case "SCALAR":
val = r.Scalar.Value
case "RANGES":
for _, pr := range r.Ranges.Range {
val += float64(pr.End - pr.Begin + 1)
}
default:
return nil, fmt.Errorf("%s is not supported", r.Type)
}
report[group][r.Name][val]++
}
}
return report, nil
}
// FitTasks computes the number tasks can be fit in a list of offer
func (c *Client) FitTasks(taskConfig *aurora.TaskConfig, offers []Offer) (int64, error) {
// count the number of tasks per limit contraint: limit.name -> limit.value -> count
limitCounts := map[string]map[string]int64{}
for _, c := range taskConfig.Constraints {
if c.Constraint.Limit != nil {
limitCounts[c.Name] = map[string]int64{}
}
}
request := ResourcesToMap(taskConfig.Resources)
// validate resource request
if len(request) == 0 {
return -1, fmt.Errorf("Resource request %v must not be empty", request)
}
isValid := false
for _, resVal := range request {
if resVal > 0 {
isValid = true
break
}
}
if !isValid {
return -1, fmt.Errorf("Resource request %v is not valid", request)
}
// pull the list of hosts under maintenance
maintHosts, err := c.MaintenanceHosts()
if err != nil {
return -1, err
}
maintHostSet := map[string]bool{}
for _, h := range maintHosts {
maintHostSet[h] = true
}
numTasks := int64(0)
for _, o := range offers {
// skip the hosts under maintenance
if maintHostSet[o.Hostname] {
continue
}
numTasksPerOffer := int64(-1)
for resName, resVal := range request {
// skip as we can fit a infinite number of tasks with 0 demand.
if resVal == 0 {
continue
}
avail := 0.0
for _, r := range o.Resources {
if r.Name != resName {
continue
}
switch r.Type {
case "SCALAR":
avail = r.Scalar.Value
case "RANGES":
for _, pr := range r.Ranges.Range {
avail += float64(pr.End - pr.Begin + 1)
}
default:
return -1, fmt.Errorf("%s is not supported", r.Type)
}
}
numTasksPerResource := int64(avail / resVal)
if numTasksPerResource < numTasksPerOffer || numTasksPerOffer < 0 {
numTasksPerOffer = numTasksPerResource
}
}
numTasks += fitConstraints(taskConfig, &o, limitCounts, numTasksPerOffer)
}
return numTasks, nil
}
func fitConstraints(taskConfig *aurora.TaskConfig,
offer *Offer,
limitCounts map[string]map[string]int64,
numTasksPerOffer int64) int64 {
// check dedicated attributes vs. constraints
if !isDedicated(offer, taskConfig.Job.Role, taskConfig.Constraints) {
return 0
}
limitConstraints := []*aurora.Constraint{}
for _, c := range taskConfig.Constraints {
// look for corresponding attribute
attFound := false
for _, a := range offer.Attributes {
if a.Name == c.Name {
attFound = true
}
}
// constraint not found in offer's attributes
if !attFound {
return 0
}
if c.Constraint.Value != nil && !valueConstraint(offer, c) {
// value constraint is not satisfied
return 0
} else if c.Constraint.Limit != nil {
limitConstraints = append(limitConstraints, c)
limit := limitConstraint(offer, c, limitCounts)
if numTasksPerOffer > limit && limit >= 0 {
numTasksPerOffer = limit
}
}
}
// update limitCounts
for _, c := range limitConstraints {
for _, a := range offer.Attributes {
if a.Name == c.Name {
limitCounts[a.Name][a.Text.Value] += numTasksPerOffer
}
}
}
return numTasksPerOffer
}
func isDedicated(offer *Offer, role string, constraints []*aurora.Constraint) bool {
// get all dedicated attributes of an offer
dedicatedAtts := map[string]bool{}
for _, a := range offer.Attributes {
if a.Name == "dedicated" {
dedicatedAtts[a.Text.Value] = true
}
}
if len(dedicatedAtts) == 0 {
return true
}
// check if constraints are matching dedicated attributes
matched := false
for _, c := range constraints {
if c.Name == "dedicated" && c.Constraint.Value != nil {
found := false
for _, v := range c.Constraint.Value.Values {
if dedicatedAtts[v] && strings.HasPrefix(v, fmt.Sprintf("%s/", role)) {
found = true
break
}
}
if found {
matched = true
} else {
return false
}
}
}
return matched
}
// valueConstraint checks Value Contraints of task if the are matched by the offer.
// more details can be found here https://aurora.apache.org/documentation/latest/features/constraints/
func valueConstraint(offer *Offer, constraint *aurora.Constraint) bool {
matched := false
for _, a := range offer.Attributes {
if a.Name == constraint.Name {
for _, v := range constraint.Constraint.Value.Values {
matched = (a.Text.Value == v && !constraint.Constraint.Value.Negated) ||
(a.Text.Value != v && constraint.Constraint.Value.Negated)
if matched {
break
}
}
if matched {
break
}
}
}
return matched
}
// limitConstraint limits the number of pods on a group which has the same attribute.
// more details can be found here https://aurora.apache.org/documentation/latest/features/constraints/
func limitConstraint(offer *Offer, constraint *aurora.Constraint, limitCounts map[string]map[string]int64) int64 {
limit := int64(-1)
for _, a := range offer.Attributes {
// limit constraint found
if a.Name == constraint.Name {
curr := limitCounts[a.Name][a.Text.Value]
currLimit := int64(constraint.Constraint.Limit.Limit)
if curr >= currLimit {
return 0
}
if currLimit-curr < limit || limit < 0 {
limit = currLimit - curr
}
}
}
return limit
}

1032
realis.go

File diff suppressed because it is too large Load diff

View file

@ -1,293 +1,274 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"context"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/pkg/errors"
)
// TODO(rdelvalle): Consider moving these functions to another interface. It would be a backwards incompatible change,
// but would add safety.
// Set a list of nodes to DRAINING. This means nothing will be able to be scheduled on them and any existing
// tasks will be killed and re-scheduled elsewhere in the cluster. Tasks from DRAINING nodes are not guaranteed
// to return to running unless there is enough capacity in the cluster to run them.
func (c *Client) DrainHosts(hosts ...string) ([]*aurora.HostStatus, error) {
func (r *realisClient) DrainHosts(hosts ...string) (*aurora.Response, *aurora.DrainHostsResult_, error) {
var result *aurora.DrainHostsResult_
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to drain")
return nil, nil, errors.New("no hosts provided to drain")
}
drainList := aurora.NewHosts()
drainList.HostNames = hosts
c.logger.DebugPrintf("DrainHosts Thrift Payload: %v\n", drainList)
r.logger.debugPrintf("DrainHosts Thrift Payload: %v\n", drainList)
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.DrainHosts(context.TODO(), drainList)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.DrainHosts(context.TODO(), drainList)
})
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to recover connection")
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetDrainHostsResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
if resp.GetResult_() != nil {
result = resp.GetResult_().GetDrainHostsResult_()
}
return resp.GetResult_().GetDrainHostsResult_().GetStatuses(), nil
return resp, result, nil
}
// Start SLA Aware Drain.
// defaultSlaPolicy is the fallback SlaPolicy to use if a task does not have an SlaPolicy.
// After timeoutSecs, tasks will be forcefully drained without checking SLA.
func (c *Client) SLADrainHosts(policy *aurora.SlaPolicy, timeout int64, hosts ...string) ([]*aurora.HostStatus, error) {
func (r *realisClient) SLADrainHosts(
policy *aurora.SlaPolicy,
timeout int64,
hosts ...string) (*aurora.DrainHostsResult_, error) {
var result *aurora.DrainHostsResult_
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to drain")
}
if policy == nil || policy.CountSetFieldsSlaPolicy() == 0 {
policy = &defaultSlaPolicy
c.logger.Printf("Warning: start draining with default sla policy %v", policy)
}
if timeout < 0 {
c.logger.Printf("Warning: timeout %d secs is invalid, draining with default timeout %d secs",
timeout,
defaultSlaDrainTimeoutSecs)
timeout = defaultSlaDrainTimeoutSecs
}
drainList := aurora.NewHosts()
drainList.HostNames = hosts
c.logger.DebugPrintf("SLADrainHosts Thrift Payload: %v\n", drainList)
r.logger.debugPrintf("SLADrainHosts Thrift Payload: %v\n", drainList)
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.SlaDrainHosts(context.TODO(), drainList, policy, timeout)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.SlaDrainHosts(context.TODO(), drainList, policy, timeout)
})
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to recover connection")
return result, errors.Wrap(retryErr, "Unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetDrainHostsResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
if resp.GetResult_() != nil {
result = resp.GetResult_().GetDrainHostsResult_()
}
return resp.GetResult_().GetDrainHostsResult_().GetStatuses(), nil
return result, nil
}
func (c *Client) StartMaintenance(hosts ...string) ([]*aurora.HostStatus, error) {
func (r *realisClient) StartMaintenance(hosts ...string) (*aurora.Response, *aurora.StartMaintenanceResult_, error) {
var result *aurora.StartMaintenanceResult_
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to start maintenance on")
return nil, nil, errors.New("no hosts provided to start maintenance on")
}
hostList := aurora.NewHosts()
hostList.HostNames = hosts
c.logger.DebugPrintf("StartMaintenance Thrift Payload: %v\n", hostList)
r.logger.debugPrintf("StartMaintenance Thrift Payload: %v\n", hostList)
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.StartMaintenance(context.TODO(), hostList)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.StartMaintenance(context.TODO(), hostList)
})
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to recover connection")
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetStartMaintenanceResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
if resp.GetResult_() != nil {
result = resp.GetResult_().GetStartMaintenanceResult_()
}
return resp.GetResult_().GetStartMaintenanceResult_().GetStatuses(), nil
return resp, result, nil
}
func (c *Client) EndMaintenance(hosts ...string) ([]*aurora.HostStatus, error) {
func (r *realisClient) EndMaintenance(hosts ...string) (*aurora.Response, *aurora.EndMaintenanceResult_, error) {
var result *aurora.EndMaintenanceResult_
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to end maintenance on")
return nil, nil, errors.New("no hosts provided to end maintenance on")
}
hostList := aurora.NewHosts()
hostList.HostNames = hosts
c.logger.DebugPrintf("EndMaintenance Thrift Payload: %v\n", hostList)
r.logger.debugPrintf("EndMaintenance Thrift Payload: %v\n", hostList)
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.EndMaintenance(context.TODO(), hostList)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.EndMaintenance(context.TODO(), hostList)
})
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to recover connection")
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetEndMaintenanceResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
if resp.GetResult_() != nil {
result = resp.GetResult_().GetEndMaintenanceResult_()
}
return resp.GetResult_().GetEndMaintenanceResult_().GetStatuses(), nil
return resp, result, nil
}
func (c *Client) MaintenanceStatus(hosts ...string) (*aurora.MaintenanceStatusResult_, error) {
func (r *realisClient) MaintenanceStatus(hosts ...string) (*aurora.Response, *aurora.MaintenanceStatusResult_, error) {
var result *aurora.MaintenanceStatusResult_
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to get maintenance status from")
return nil, nil, errors.New("no hosts provided to get maintenance status from")
}
hostList := aurora.NewHosts()
hostList.HostNames = hosts
c.logger.DebugPrintf("MaintenanceStatus Thrift Payload: %v\n", hostList)
r.logger.debugPrintf("MaintenanceStatus Thrift Payload: %v\n", hostList)
// Make thrift call. If we encounter an error sending the call, attempt to reconnect
// and continue trying to resend command until we run out of retries.
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.MaintenanceStatus(context.TODO(), hostList)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.MaintenanceStatus(context.TODO(), hostList)
})
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
}
return resp.GetResult_().GetMaintenanceStatusResult_(), nil
if resp.GetResult_() != nil {
result = resp.GetResult_().GetMaintenanceStatusResult_()
}
return resp, result, nil
}
// SetQuota sets a quota aggregate for the given role
// TODO(zircote) Currently investigating an error that is returned from thrift calls that include resources for `NamedPort` and `NumGpu`
func (c *Client) SetQuota(role string, cpu *float64, ramMb *int64, diskMb *int64) error {
ramResource := aurora.NewResource()
ramResource.RamMb = ramMb
cpuResource := aurora.NewResource()
cpuResource.NumCpus = cpu
diskResource := aurora.NewResource()
diskResource.DiskMb = diskMb
// TODO(zircote) Currently investigating an error that is returned
// from thrift calls that include resources for `NamedPort` and `NumGpu`
func (r *realisClient) SetQuota(role string, cpu *float64, ramMb *int64, diskMb *int64) (*aurora.Response, error) {
quota := &aurora.ResourceAggregate{
Resources: []*aurora.Resource{{NumCpus: cpu}, {RamMb: ramMb}, {DiskMb: diskMb}},
}
quota := aurora.NewResourceAggregate()
quota.Resources = []*aurora.Resource{ramResource, cpuResource, diskResource}
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.SetQuota(context.TODO(), role, quota)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.SetQuota(context.TODO(), role, quota)
})
if retryErr != nil {
return errors.Wrap(retryErr, "unable to set role quota")
return resp, errors.Wrap(retryErr, "Unable to set role quota")
}
return retryErr
return resp, retryErr
}
// GetQuota returns the resource aggregate for the given role
func (c *Client) GetQuota(role string) (*aurora.GetQuotaResult_, error) {
func (r *realisClient) GetQuota(role string) (*aurora.Response, error) {
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.GetQuota(context.TODO(), role)
},
nil,
)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.GetQuota(context.TODO(), role)
})
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to get role quota")
return resp, errors.Wrap(retryErr, "Unable to get role quota")
}
if resp == nil || resp.GetResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
return resp.GetResult_().GetGetQuotaResult_(), nil
return resp, retryErr
}
// Force Aurora Scheduler to perform a snapshot and write to Mesos log
func (c *Client) Snapshot() error {
func (r *realisClient) Snapshot() error {
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.Snapshot(context.TODO())
},
nil,
)
_, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.Snapshot(context.TODO())
})
if retryErr != nil {
return errors.Wrap(retryErr, "unable to recover connection")
return errors.Wrap(retryErr, "Unable to recover connection")
}
return nil
}
// Force Aurora Scheduler to write backup file to a file in the backup directory
func (c *Client) PerformBackup() error {
func (r *realisClient) PerformBackup() error {
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.PerformBackup(context.TODO())
},
nil,
)
_, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.PerformBackup(context.TODO())
})
if retryErr != nil {
return errors.Wrap(retryErr, "unable to recover connection")
return errors.Wrap(retryErr, "Unable to recover connection")
}
return nil
}
// Force an Implicit reconciliation between Mesos and Aurora
func (c *Client) ForceImplicitTaskReconciliation() error {
func (r *realisClient) ForceImplicitTaskReconciliation() error {
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.TriggerImplicitTaskReconciliation(context.TODO())
},
nil,
)
_, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.TriggerImplicitTaskReconciliation(context.TODO())
})
if retryErr != nil {
return errors.Wrap(retryErr, "unable to recover connection")
return errors.Wrap(retryErr, "Unable to recover connection")
}
return nil
}
// Force an Explicit reconciliation between Mesos and Aurora
func (c *Client) ForceExplicitTaskReconciliation(batchSize *int32) error {
func (r *realisClient) ForceExplicitTaskReconciliation(batchSize *int32) error {
if batchSize != nil && *batchSize < 1 {
return errors.New("invalid batch size.")
return errors.New("invalid batch size")
}
settings := aurora.NewExplicitReconciliationSettings()
settings.BatchSize = batchSize
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.TriggerExplicitTaskReconciliation(context.TODO(), settings)
},
nil,
)
_, retryErr := r.thriftCallWithRetries(false,
func() (*aurora.Response, error) {
return r.adminClient.TriggerExplicitTaskReconciliation(context.TODO(), settings)
})
if retryErr != nil {
return errors.Wrap(retryErr, "unable to recover connection")
return errors.Wrap(retryErr, "Unable to recover connection")
}
return nil

View file

@ -1,181 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"strings"
"time"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
type clientConfig struct {
username, password string
url string
timeout time.Duration
transportProtocol TransportProtocol
cluster *Cluster
backoff Backoff
transport thrift.TTransport
protoFactory thrift.TProtocolFactory
logger *LevelLogger
insecureSkipVerify bool
certsPath string
clientKey, clientCert string
options []ClientOption
debug bool
trace bool
zkOptions []ZKOpt
failOnPermanentErrors bool
}
var defaultBackoff = Backoff{
Steps: 3,
Duration: 10 * time.Second,
Factor: 5.0,
Jitter: 0.1,
}
var defaultSlaPolicy = aurora.SlaPolicy{
PercentageSlaPolicy: &aurora.PercentageSlaPolicy{
Percentage: 66,
DurationSecs: 300,
},
}
const defaultSlaDrainTimeoutSecs = 900
type TransportProtocol int
const (
unsetProtocol TransportProtocol = iota
jsonProtocol
binaryProtocol
)
type ClientOption func(*clientConfig)
// clientConfig sets for options in clientConfig.
func BasicAuth(username, password string) ClientOption {
return func(config *clientConfig) {
config.username = username
config.password = password
}
}
func SchedulerUrl(url string) ClientOption {
return func(config *clientConfig) {
config.url = url
}
}
func Timeout(timeout time.Duration) ClientOption {
return func(config *clientConfig) {
config.timeout = timeout
}
}
func ZKCluster(cluster *Cluster) ClientOption {
return func(config *clientConfig) {
config.cluster = cluster
}
}
func ZKUrl(url string) ClientOption {
opts := []ZKOpt{ZKEndpoints(strings.Split(url, ",")...), ZKPath("/aurora/scheduler")}
return func(config *clientConfig) {
if config.zkOptions == nil {
config.zkOptions = opts
} else {
config.zkOptions = append(config.zkOptions, opts...)
}
}
}
func ThriftJSON() ClientOption {
return func(config *clientConfig) {
config.transportProtocol = jsonProtocol
}
}
func ThriftBinary() ClientOption {
return func(config *clientConfig) {
config.transportProtocol = binaryProtocol
}
}
func BackOff(b Backoff) ClientOption {
return func(config *clientConfig) {
config.backoff = b
}
}
func InsecureSkipVerify(InsecureSkipVerify bool) ClientOption {
return func(config *clientConfig) {
config.insecureSkipVerify = InsecureSkipVerify
}
}
func CertsPath(certspath string) ClientOption {
return func(config *clientConfig) {
config.certsPath = certspath
}
}
func ClientCerts(clientKey, clientCert string) ClientOption {
return func(config *clientConfig) {
config.clientKey, config.clientCert = clientKey, clientCert
}
}
// Use this option if you'd like to override default settings for connecting to Zookeeper.
// See zk.go for what is possible to set as an option.
func ZookeeperOptions(opts ...ZKOpt) ClientOption {
return func(config *clientConfig) {
config.zkOptions = opts
}
}
// Using the word set to avoid name collision with Interface.
func SetLogger(l Logger) ClientOption {
return func(config *clientConfig) {
config.logger = &LevelLogger{Logger: l}
}
}
// Enable debug statements.
func Debug() ClientOption {
return func(config *clientConfig) {
config.debug = true
}
}
// Enable trace statements.
func Trace() ClientOption {
return func(config *clientConfig) {
config.trace = true
}
}
// FailOnPermanentErrors - If the client encounters a connection error the standard library
// considers permanent, stop retrying and return an error to the user.
func FailOnPermanentErrors() ClientOption {
return func(config *clientConfig) {
config.failOnPermanentErrors = true
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,72 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetCACerts(t *testing.T) {
certs, err := GetCerts("./examples/certs")
require.NoError(t, err)
assert.Equal(t, len(certs.Subjects()), 2)
}
func TestAuroraURLValidator(t *testing.T) {
t.Run("badURL", func(t *testing.T) {
url, err := validateAuroraAddress("http://badurl.com/badpath")
assert.Empty(t, url)
assert.Error(t, err)
})
t.Run("URLHttp", func(t *testing.T) {
url, err := validateAuroraAddress("http://goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLHttps", func(t *testing.T) {
url, err := validateAuroraAddress("https://goodurl.com:8081/api")
assert.Equal(t, "https://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoPath", func(t *testing.T) {
url, err := validateAuroraAddress("http://goodurl.com:8081")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("ipAddrNoPath", func(t *testing.T) {
url, err := validateAuroraAddress("http://192.168.1.33:8081")
assert.Equal(t, "http://192.168.1.33:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocol", func(t *testing.T) {
url, err := validateAuroraAddress("goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocolNoPathNoPort", func(t *testing.T) {
url, err := validateAuroraAddress("goodurl.com")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
}

View file

@ -17,8 +17,9 @@ package response
import (
"bytes"
"errors"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/paypal/gorealis/gen-go/apache/aurora"
)
// Get key from a response created by a StartJobUpdate call
@ -35,11 +36,19 @@ func ScheduleStatusResult(resp *aurora.Response) *aurora.ScheduleStatusResult_ {
}
func JobUpdateSummaries(resp *aurora.Response) []*aurora.JobUpdateSummary {
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetGetJobUpdateSummariesResult_() == nil {
return nil
return resp.GetResult_().GetGetJobUpdateSummariesResult_().GetUpdateSummaries()
}
// Deprecated: Replaced by checks done inside of thriftCallHelper
func ResponseCodeCheck(resp *aurora.Response) (*aurora.Response, error) {
if resp == nil {
return resp, errors.New("Response is nil")
}
if resp.GetResponseCode() != aurora.ResponseCode_OK {
return resp, errors.New(CombineMessage(resp))
}
return resp.GetResult_().GetGetJobUpdateSummariesResult_().GetUpdateSummaries()
return resp, nil
}
// Based on aurora client: src/main/python/apache/aurora/client/base.py

198
retry.go
View file

@ -21,8 +21,8 @@ import (
"time"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/response"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/paypal/gorealis/response"
"github.com/pkg/errors"
)
@ -61,11 +61,10 @@ type ConditionFunc func() (done bool, err error)
//
// If the condition never returns true, ErrWaitTimeout is returned. Errors
// do not cause the function to return.
func ExponentialBackoff(backoff Backoff, logger Logger, condition ConditionFunc) error {
func ExponentialBackoff(backoff Backoff, logger logger, condition ConditionFunc) error {
var err error
var ok bool
var curStep int
duration := backoff.Duration
for curStep = 0; curStep < backoff.Steps; curStep++ {
@ -77,7 +76,8 @@ func ExponentialBackoff(backoff Backoff, logger Logger, condition ConditionFunc)
adjusted = Jitter(duration, backoff.Jitter)
}
logger.Printf("A retryable error occurred during function call, backing off for %v before retrying\n", adjusted)
logger.Printf(
"A retryable error occurred during function call, backing off for %v before retrying\n", adjusted)
time.Sleep(adjusted)
duration = time.Duration(float64(duration) * backoff.Factor)
}
@ -114,23 +114,17 @@ func ExponentialBackoff(backoff Backoff, logger Logger, condition ConditionFunc)
type auroraThriftCall func() (resp *aurora.Response, err error)
// verifyOntimeout defines the type of function that will be used to verify whether a Thirft call to the Scheduler
// made it to the scheduler or not. In general, these types of functions will have to interact with the scheduler
// through the very same Thrift API which previously encountered a time-out from the client.
// This means that the functions themselves should be kept to a minimum number of Thrift calls.
// It should also be noted that this is a best effort mechanism and
// is likely to fail for the same reasons that the original call failed.
type verifyOnTimeout func() (*aurora.Response, bool)
// Duplicates the functionality of ExponentialBackoff but is specifically targeted towards ThriftCalls.
func (c *Client) thriftCallWithRetries(returnOnTimeout bool, thriftCall auroraThriftCall,
verifyOnTimeout verifyOnTimeout) (*aurora.Response, error) {
func (r *realisClient) thriftCallWithRetries(
returnOnTimeout bool,
thriftCall auroraThriftCall) (*aurora.Response, error) {
var resp *aurora.Response
var clientErr error
var curStep int
timeouts := 0
backoff := c.config.backoff
backoff := r.config.backoff
duration := backoff.Duration
for curStep = 0; curStep < backoff.Steps; curStep++ {
@ -142,8 +136,8 @@ func (c *Client) thriftCallWithRetries(returnOnTimeout bool, thriftCall auroraTh
adjusted = Jitter(duration, backoff.Jitter)
}
c.logger.Printf(
"A retryable error occurred during thrift call, backing off for %v before retry %v",
r.logger.Printf(
"A retryable error occurred during thrift call, backing off for %v before retry %v\n",
adjusted,
curStep)
@ -155,104 +149,101 @@ func (c *Client) thriftCallWithRetries(returnOnTimeout bool, thriftCall auroraTh
// Placing this in an anonymous function in order to create a new, short-lived stack allowing unlock
// to be run in case of a panic inside of thriftCall.
func() {
c.lock.Lock()
defer c.lock.Unlock()
r.lock.Lock()
defer r.lock.Unlock()
resp, clientErr = thriftCall()
c.logger.TracePrintf("Aurora Thrift Call ended resp: %v clientErr: %v", resp, clientErr)
r.logger.tracePrintf("Aurora Thrift Call ended resp: %v clientErr: %v\n", resp, clientErr)
}()
// Check if our thrift call is returning an error. This is a retryable event as we don't know
// if it was caused by network issues.
if clientErr != nil {
// Print out the error to the user
c.logger.Printf("Client Error: %v", clientErr)
r.logger.Printf("Client Error: %v\n", clientErr)
temporary, timedout := isConnectionError(clientErr)
if !temporary && c.RealisConfig().failOnPermanentErrors {
return nil, errors.Wrap(clientErr, "permanent connection error")
}
// Determine if error is a temporary URL error by going up the stack
e, ok := clientErr.(thrift.TTransportException)
if ok {
r.logger.debugPrint("Encountered a transport exception")
// There exists a corner case where thrift payload was received by Aurora but
// connection timed out before Aurora was able to reply.
// Users can take special action on a timeout by using IsTimedout and reacting accordingly
// if they have configured the client to return on a timeout.
if timedout && returnOnTimeout {
return resp, newTimedoutError(errors.New("client connection closed before server answer"))
e, ok := e.Err().(*url.Error)
if ok {
// EOF error occurs when the server closes the read buffer of the client. This is common
// when the server is overloaded and should be retried. All other errors that are permanent
// will not be retried.
if e.Err != io.EOF && !e.Temporary() && r.RealisConfig().failOnPermanentErrors {
return nil, errors.Wrap(clientErr, "permanent connection error")
}
// Corner case where thrift payload was received by Aurora but connection timedout before Aurora was
// able to reply. In this case we will return whatever response was received and a TimedOut behaving
// error. Users can take special action on a timeout by using IsTimedout and reacting accordingly.
if e.Timeout() {
timeouts++
r.logger.debugPrintf(
"Client closed connection (timedout) %d times before server responded, "+
"consider increasing connection timeout",
timeouts)
if returnOnTimeout {
return resp, newTimedoutError(errors.New("client connection closed before server answer"))
}
}
}
}
// In the future, reestablish connection should be able to check if it is actually possible
// to make a thrift call to Aurora. For now, a reconnect should always lead to a retry.
// Ignoring error due to the fact that an error should be retried regardless
reestablishErr := c.ReestablishConn()
reestablishErr := r.ReestablishConn()
if reestablishErr != nil {
c.logger.DebugPrintf("error re-establishing connection ", reestablishErr)
r.logger.debugPrintf("error re-establishing connection ", reestablishErr)
}
} else {
// If there was no client error, but the response is nil, something went wrong.
// Ideally, we'll never encounter this but we're placing a safeguard here.
if resp == nil {
return nil, errors.New("response from aurora is nil")
}
// If users did not opt for a return on timeout in order to react to a timedout error,
// attempt to verify that the call made it to the scheduler after the connection was re-established.
if timedout {
timeouts++
c.logger.DebugPrintf(
"Client closed connection %d times before server responded, "+
"consider increasing connection timeout",
timeouts)
// Check Response Code from thrift and make a decision to continue retrying or not.
switch responseCode := resp.GetResponseCode(); responseCode {
// Allow caller to provide a function which checks if the original call was successful before
// it timed out.
if verifyOnTimeout != nil {
if verifyResp, ok := verifyOnTimeout(); ok {
c.logger.Print("verified that the call went through successfully after a client timeout")
// Response here might be different than the original as it is no longer constructed
// by the scheduler but mimicked.
// This is OK since the scheduler is very unlikely to change responses at this point in its
// development cycle but we must be careful to not return an incorrectly constructed response.
return verifyResp, nil
}
}
// If the thrift call succeeded, stop retrying
case aurora.ResponseCode_OK:
return resp, nil
// If the response code is transient, continue retrying
case aurora.ResponseCode_ERROR_TRANSIENT:
r.logger.Println("Aurora replied with Transient error code, retrying")
continue
// Failure scenarios, these indicate a bad payload or a bad config. Stop retrying.
case aurora.ResponseCode_INVALID_REQUEST,
aurora.ResponseCode_ERROR,
aurora.ResponseCode_AUTH_FAILED,
aurora.ResponseCode_JOB_UPDATING_ERROR:
r.logger.Printf("Terminal Response Code %v from Aurora, won't retry\n", resp.GetResponseCode().String())
return resp, errors.New(response.CombineMessage(resp))
// The only case that should fall down to here is a WARNING response code.
// It is currently not used as a response in the scheduler so it is unknown how to handle it.
default:
r.logger.debugPrintf("unhandled response code %v received from Aurora\n", responseCode)
return nil, errors.Errorf("unhandled response code from Aurora %v", responseCode.String())
}
// Retry the thrift payload
continue
}
// If there was no client error, but the response is nil, something went wrong.
// Ideally, we'll never encounter this but we're placing a safeguard here.
if resp == nil {
return nil, errors.New("response from aurora is nil")
}
// Check Response Code from thrift and make a decision to continue retrying or not.
switch responseCode := resp.GetResponseCode(); responseCode {
// If the thrift call succeeded, stop retrying
case aurora.ResponseCode_OK:
return resp, nil
// If the response code is transient, continue retrying
case aurora.ResponseCode_ERROR_TRANSIENT:
c.logger.Println("Aurora replied with Transient error code, retrying")
continue
// Failure scenarios, these indicate a bad payload or a bad clientConfig. Stop retrying.
case aurora.ResponseCode_INVALID_REQUEST,
aurora.ResponseCode_ERROR,
aurora.ResponseCode_AUTH_FAILED,
aurora.ResponseCode_JOB_UPDATING_ERROR:
c.logger.Printf("Terminal Response Code %v from Aurora, won't retry\n", resp.GetResponseCode().String())
return resp, errors.New(response.CombineMessage(resp))
// The only case that should fall down to here is a WARNING response code.
// It is currently not used as a response in the scheduler so it is unknown how to handle it.
default:
c.logger.DebugPrintf("unhandled response code %v received from Aurora\n", responseCode)
return nil, errors.Errorf("unhandled response code from Aurora %v", responseCode.String())
}
}
r.logger.debugPrintf("it took %v retries to complete this operation\n", curStep)
if curStep > 1 {
c.config.logger.Printf("this thrift call was retried %d time(s)", curStep)
r.config.logger.Printf("retried this thrift call %d time(s)", curStep)
}
// Provide more information to the user wherever possible.
@ -262,30 +253,3 @@ func (c *Client) thriftCallWithRetries(returnOnTimeout bool, thriftCall auroraTh
return nil, newRetryError(errors.New("ran out of retries"), curStep)
}
// isConnectionError processes the error received by the client.
// The return values indicate whether this was determined to be a temporary error
// and whether it was determined to be a timeout error
func isConnectionError(err error) (bool, bool) {
// Determine if error is a temporary URL error by going up the stack
transportException, ok := err.(thrift.TTransportException)
if !ok {
return false, false
}
urlError, ok := transportException.Err().(*url.Error)
if !ok {
return false, false
}
// EOF error occurs when the server closes the read buffer of the client. This is common
// when the server is overloaded and we consider it temporary.
// All other which are not temporary as per the member function Temporary(),
// are considered not temporary (permanent).
if urlError.Err != io.EOF && !urlError.Temporary() {
return false, false
}
return true, urlError.Timeout()
}

13
runTests.sh Executable file
View file

@ -0,0 +1,13 @@
#!/bin/bash
docker-compose up -d
# If running docker-compose up gives any error, don't do anything.
if [ $? -ne 0 ]; then
exit
fi
# Since we run our docker compose setup in bridge mode to be able to run on MacOS, we have to launch a Docker container within the bridge network in order to avoid any routing issues.
docker run --rm -t -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10-stretch go test -v github.com/paypal/gorealis $@
docker-compose down

View file

@ -1,4 +1,4 @@
#!/bin/bash
# Since we run our docker compose setup in bridge mode to be able to run on MacOS, we have to launch a Docker container within the bridge network in order to avoid any routing issues.
docker run --rm -t -w /gorealis -v $GOPATH/pkg:/go/pkg -v $(pwd):/gorealis --network gorealis_aurora_cluster golang:1.17-buster go test -v github.com/aurora-scheduler/gorealis/v2 $@
docker run --rm -t -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10-stretch go test -v github.com/paypal/gorealis $@

465
task.go
View file

@ -1,465 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"encoding/json"
"strconv"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
type ResourceType int
const (
CPU ResourceType = iota
RAM
DISK
GPU
)
const (
dedicated = "dedicated"
portPrefix = "org.apache.aurora.port."
)
type AuroraTask struct {
task *aurora.TaskConfig
resources map[ResourceType]*aurora.Resource
portCount int
thermos *ThermosExecutor
}
func NewTask() *AuroraTask {
numCpus := &aurora.Resource{}
ramMb := &aurora.Resource{}
diskMb := &aurora.Resource{}
numCpus.NumCpus = new(float64)
ramMb.RamMb = new(int64)
diskMb.DiskMb = new(int64)
resources := map[ResourceType]*aurora.Resource{CPU: numCpus, RAM: ramMb, DISK: diskMb}
return &AuroraTask{task: &aurora.TaskConfig{
Job: &aurora.JobKey{},
MesosFetcherUris: make([]*aurora.MesosFetcherURI, 0),
Metadata: make([]*aurora.Metadata, 0),
Constraints: make([]*aurora.Constraint, 0),
// Container is a Union so one container field must be set. Set Mesos by default.
Container: NewMesosContainer().Build(),
Resources: []*aurora.Resource{numCpus, ramMb, diskMb},
},
resources: resources,
portCount: 0}
}
// Helper method to convert aurora.TaskConfig to gorealis AuroraTask type
func TaskFromThrift(config *aurora.TaskConfig) *AuroraTask {
newTask := NewTask()
// Pass values using receivers as much as possible
newTask.
Environment(config.Job.Environment).
Role(config.Job.Role).
Name(config.Job.Name).
MaxFailure(config.MaxTaskFailures).
IsService(config.IsService).
Priority(config.Priority)
if config.Tier != nil {
newTask.Tier(*config.Tier)
}
if config.Production != nil {
newTask.Production(*config.Production)
}
if config.ExecutorConfig != nil {
newTask.
ExecutorName(config.ExecutorConfig.Name).
ExecutorData(config.ExecutorConfig.Data)
}
if config.PartitionPolicy != nil {
newTask.PartitionPolicy(
aurora.PartitionPolicy{
Reschedule: config.PartitionPolicy.Reschedule,
DelaySecs: thrift.Int64Ptr(*config.PartitionPolicy.DelaySecs),
})
}
// Make a deep copy of the task's container
if config.Container != nil {
if config.Container.Mesos != nil {
mesosContainer := NewMesosContainer()
if config.Container.Mesos.Image != nil {
if config.Container.Mesos.Image.Appc != nil {
mesosContainer.AppcImage(config.Container.Mesos.Image.Appc.Name, config.Container.Mesos.Image.Appc.ImageId)
} else if config.Container.Mesos.Image.Docker != nil {
mesosContainer.DockerImage(config.Container.Mesos.Image.Docker.Name, config.Container.Mesos.Image.Docker.Tag)
}
}
for _, vol := range config.Container.Mesos.Volumes {
mesosContainer.AddVolume(vol.ContainerPath, vol.HostPath, vol.Mode)
}
newTask.Container(mesosContainer)
} else if config.Container.Docker != nil {
dockerContainer := NewDockerContainer()
dockerContainer.Image(config.Container.Docker.Image)
for _, param := range config.Container.Docker.Parameters {
dockerContainer.AddParameter(param.Name, param.Value)
}
newTask.Container(dockerContainer)
}
}
// Copy all ports
for _, resource := range config.Resources {
// Copy only ports. Set CPU, RAM, DISK, and GPU
if resource != nil {
if resource.NamedPort != nil {
newTask.task.Resources = append(
newTask.task.Resources,
&aurora.Resource{NamedPort: thrift.StringPtr(*resource.NamedPort)},
)
newTask.portCount++
}
if resource.RamMb != nil {
newTask.RAM(*resource.RamMb)
}
if resource.NumCpus != nil {
newTask.CPU(*resource.NumCpus)
}
if resource.DiskMb != nil {
newTask.Disk(*resource.DiskMb)
}
if resource.NumGpus != nil {
newTask.GPU(*resource.NumGpus)
}
}
}
// Copy constraints
for _, constraint := range config.Constraints {
if constraint != nil && constraint.Constraint != nil {
newConstraint := aurora.Constraint{Name: constraint.Name}
taskConstraint := constraint.Constraint
if taskConstraint.Limit != nil {
newConstraint.Constraint = &aurora.TaskConstraint{
Limit: &aurora.LimitConstraint{Limit: taskConstraint.Limit.Limit},
}
newTask.task.Constraints = append(newTask.task.Constraints, &newConstraint)
} else if taskConstraint.Value != nil {
values := make([]string, 0)
for _, val := range taskConstraint.Value.Values {
values = append(values, val)
}
newConstraint.Constraint = &aurora.TaskConstraint{
Value: &aurora.ValueConstraint{Negated: taskConstraint.Value.Negated, Values: values}}
newTask.task.Constraints = append(newTask.task.Constraints, &newConstraint)
}
}
}
// Copy labels
for _, label := range config.Metadata {
newTask.task.Metadata = append(newTask.task.Metadata, &aurora.Metadata{Key: label.Key, Value: label.Value})
}
// Copy Mesos fetcher URIs
for _, uri := range config.MesosFetcherUris {
newTask.task.MesosFetcherUris = append(
newTask.task.MesosFetcherUris,
&aurora.MesosFetcherURI{
Value: uri.Value,
Extract: thrift.BoolPtr(*uri.Extract),
Cache: thrift.BoolPtr(*uri.Cache),
})
}
return newTask
}
// Set AuroraTask Key environment.
func (t *AuroraTask) Environment(env string) *AuroraTask {
t.task.Job.Environment = env
return t
}
// Set AuroraTask Key Role.
func (t *AuroraTask) Role(role string) *AuroraTask {
t.task.Job.Role = role
return t
}
// Set AuroraTask Key Name.
func (t *AuroraTask) Name(name string) *AuroraTask {
t.task.Job.Name = name
return t
}
// Set name of the executor that will the task will be configured to.
func (t *AuroraTask) ExecutorName(name string) *AuroraTask {
if t.task.ExecutorConfig == nil {
t.task.ExecutorConfig = aurora.NewExecutorConfig()
}
t.task.ExecutorConfig.Name = name
return t
}
// Will be included as part of entire task inside the scheduler that will be serialized.
func (t *AuroraTask) ExecutorData(data string) *AuroraTask {
if t.task.ExecutorConfig == nil {
t.task.ExecutorConfig = aurora.NewExecutorConfig()
}
t.task.ExecutorConfig.Data = data
return t
}
func (t *AuroraTask) CPU(cpus float64) *AuroraTask {
*t.resources[CPU].NumCpus = cpus
return t
}
func (t *AuroraTask) RAM(ram int64) *AuroraTask {
*t.resources[RAM].RamMb = ram
return t
}
func (t *AuroraTask) Disk(disk int64) *AuroraTask {
*t.resources[DISK].DiskMb = disk
return t
}
func (t *AuroraTask) GPU(gpu int64) *AuroraTask {
// GPU resource must be set explicitly since the scheduler by default
// rejects jobs with GPU resources attached to it.
if _, ok := t.resources[GPU]; !ok {
t.resources[GPU] = &aurora.Resource{}
t.task.Resources = append(t.task.Resources, t.resources[GPU])
}
t.resources[GPU].NumGpus = &gpu
return t
}
func (t *AuroraTask) Tier(tier string) *AuroraTask {
t.task.Tier = &tier
return t
}
// How many failures to tolerate before giving up.
func (t *AuroraTask) MaxFailure(maxFail int32) *AuroraTask {
t.task.MaxTaskFailures = maxFail
return t
}
// Restart the job's tasks if they fail
func (t *AuroraTask) IsService(isService bool) *AuroraTask {
t.task.IsService = isService
return t
}
//set priority for preemption or priority-queueing
func (t *AuroraTask) Priority(priority int32) *AuroraTask {
t.task.Priority = priority
return t
}
func (t *AuroraTask) Production(production bool) *AuroraTask {
t.task.Production = &production
return t
}
// Add a list of URIs with the same extract and cache configuration. Scheduler must have
// --enable_mesos_fetcher flag enabled. Currently there is no duplicate detection.
func (t *AuroraTask) AddURIs(extract bool, cache bool, values ...string) *AuroraTask {
for _, value := range values {
t.task.MesosFetcherUris = append(
t.task.MesosFetcherUris,
&aurora.MesosFetcherURI{Value: value, Extract: &extract, Cache: &cache})
}
return t
}
// Adds a Mesos label to the job. Note that Aurora will add the
// prefix "org.apache.aurora.metadata." to the beginning of each key.
func (t *AuroraTask) AddLabel(key string, value string) *AuroraTask {
t.task.Metadata = append(t.task.Metadata, &aurora.Metadata{Key: key, Value: value})
return t
}
// Add a named port to the job configuration These are random ports as it's
// not currently possible to request specific ports using Aurora.
func (t *AuroraTask) AddNamedPorts(names ...string) *AuroraTask {
t.portCount += len(names)
for _, name := range names {
t.task.Resources = append(t.task.Resources, &aurora.Resource{NamedPort: &name})
}
return t
}
// Adds a request for a number of ports to the job configuration. The names chosen for these ports
// will be org.apache.aurora.port.X, where X is the current port count for the job configuration
// starting at 0. These are random ports as it's not currently possible to request
// specific ports using Aurora.
func (t *AuroraTask) AddPorts(num int) *AuroraTask {
start := t.portCount
t.portCount += num
for i := start; i < t.portCount; i++ {
portName := portPrefix + strconv.Itoa(i)
t.task.Resources = append(t.task.Resources, &aurora.Resource{NamedPort: &portName})
}
return t
}
// From Aurora Docs:
// Add a Value constraint
// name - Mesos slave attribute that the constraint is matched against.
// If negated = true , treat this as a 'not' - to avoid specific values.
// Values - list of values we look for in attribute name
func (t *AuroraTask) AddValueConstraint(name string, negated bool, values ...string) *AuroraTask {
t.task.Constraints = append(t.task.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: &aurora.ValueConstraint{
Negated: negated,
Values: values,
},
Limit: nil,
},
})
return t
}
// From Aurora Docs:
// A constraint that specifies the maximum number of active tasks on a host with
// a matching attribute that may be scheduled simultaneously.
func (t *AuroraTask) AddLimitConstraint(name string, limit int32) *AuroraTask {
t.task.Constraints = append(t.task.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: nil,
Limit: &aurora.LimitConstraint{Limit: limit},
},
})
return t
}
// From Aurora Docs:
// dedicated attribute. Aurora treats this specially, and only allows matching jobs
// to run on these machines, and will only schedule matching jobs on these machines.
// When a job is created, the scheduler requires that the $role component matches
// the role field in the job configuration, and will reject the job creation otherwise.
// A wildcard (*) may be used for the role portion of the dedicated attribute, which
// will allow any owner to elect for a job to run on the host(s)
func (t *AuroraTask) AddDedicatedConstraint(role, name string) *AuroraTask {
t.AddValueConstraint(dedicated, false, role+"/"+name)
return t
}
// Set a container to run for the job configuration to run.
func (t *AuroraTask) Container(container Container) *AuroraTask {
t.task.Container = container.Build()
return t
}
func (t *AuroraTask) TaskConfig() *aurora.TaskConfig {
return t.task
}
func (t *AuroraTask) JobKey() aurora.JobKey {
return *t.task.Job
}
func (t *AuroraTask) Clone() *AuroraTask {
newTask := TaskFromThrift(t.task)
if t.thermos != nil {
newTask.ThermosExecutor(*t.thermos.Clone())
}
return newTask
}
func (t *AuroraTask) ThermosExecutor(thermos ThermosExecutor) *AuroraTask {
t.thermos = &thermos
return t
}
func (t *AuroraTask) BuildThermosPayload() error {
if t.thermos != nil {
// Set the correct resources
if t.resources[CPU].NumCpus != nil {
t.thermos.cpu(*t.resources[CPU].NumCpus)
}
if t.resources[RAM].RamMb != nil {
t.thermos.ram(*t.resources[RAM].RamMb)
}
if t.resources[DISK].DiskMb != nil {
t.thermos.disk(*t.resources[DISK].DiskMb)
}
if t.resources[GPU] != nil && t.resources[GPU].NumGpus != nil {
t.thermos.gpu(*t.resources[GPU].NumGpus)
}
payload, err := json.Marshal(t.thermos)
if err != nil {
return err
}
t.ExecutorName(aurora.AURORA_EXECUTOR_NAME)
t.ExecutorData(string(payload))
}
return nil
}
// Set a partition policy for the job configuration to implement.
func (t *AuroraTask) PartitionPolicy(policy aurora.PartitionPolicy) *AuroraTask {
t.task.PartitionPolicy = &policy
return t
}

View file

@ -1,59 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis_test
import (
"testing"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/stretchr/testify/assert"
)
func TestAuroraTask_Clone(t *testing.T) {
task0 := realis.NewTask().
Environment("development").
Role("ubuntu").
Name("this_is_a_test").
ExecutorName(aurora.AURORA_EXECUTOR_NAME).
ExecutorData("{fake:payload}").
CPU(10).
RAM(643).
Disk(1000).
IsService(true).
Priority(1).
Production(false).
AddPorts(10).
Tier("preferred").
MaxFailure(23).
AddURIs(true, true, "testURI").
AddLabel("Test", "Value").
AddNamedPorts("test").
AddValueConstraint("test", false, "testing").
AddLimitConstraint("test_limit", 1).
AddDedicatedConstraint("ubuntu", "name").
Container(realis.NewDockerContainer().AddParameter("hello", "world").Image("testImg"))
task1 := task0.Clone()
assert.EqualValues(t, task0, task1, "Clone does not return the correct deep copy of AuroraTask")
task0.Container(realis.NewMesosContainer().
AppcImage("test", "testing").
AddVolume("test", "test", aurora.Mode_RW))
task2 := task0.Clone()
assert.EqualValues(t, task0, task2, "Clone does not return the correct deep copy of AuroraTask")
}

View file

@ -1,195 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import "encoding/json"
type ThermosExecutor struct {
Task ThermosTask `json:"task""`
order *ThermosConstraint `json:"-"`
}
type ThermosTask struct {
Processes map[string]*ThermosProcess `json:"processes"`
Constraints []*ThermosConstraint `json:"constraints"`
Resources thermosResources `json:"resources"`
}
type ThermosConstraint struct {
Order []string `json:"order,omitempty"`
}
// This struct should always be controlled by the Aurora job struct.
// Therefore it is private.
type thermosResources struct {
CPU *float64 `json:"cpu,omitempty"`
Disk *int64 `json:"disk,omitempty"`
RAM *int64 `json:"ram,omitempty"`
GPU *int64 `json:"gpu,omitempty"`
}
type ThermosProcess struct {
Name string `json:"name"`
Cmdline string `json:"cmdline"`
Daemon bool `json:"daemon"`
Ephemeral bool `json:"ephemeral"`
MaxFailures int `json:"max_failures"`
MinDuration int `json:"min_duration"`
Final bool `json:"final"`
}
func NewThermosProcess(name, command string) ThermosProcess {
return ThermosProcess{
Name: name,
Cmdline: command,
MaxFailures: 1,
Daemon: false,
Ephemeral: false,
MinDuration: 5,
Final: false}
}
// Processes must have unique names. Adding a process whose name already exists will
// result in overwriting the previous version of the process.
func (t *ThermosExecutor) AddProcess(process ThermosProcess) *ThermosExecutor {
if len(t.Task.Processes) == 0 {
t.Task.Processes = make(map[string]*ThermosProcess, 0)
}
t.Task.Processes[process.Name] = &process
// Add Process to order
t.addToOrder(process.Name)
return t
}
// Only constraint that should be added for now is the order of execution, therefore this
// receiver is private.
func (t *ThermosExecutor) addConstraint(constraint *ThermosConstraint) *ThermosExecutor {
if len(t.Task.Constraints) == 0 {
t.Task.Constraints = make([]*ThermosConstraint, 0)
}
t.Task.Constraints = append(t.Task.Constraints, constraint)
return t
}
// Order in which the Processes should be executed. Index 0 will be executed first, index N will be executed last.
func (t *ThermosExecutor) ProcessOrder(order ...string) *ThermosExecutor {
if t.order == nil {
t.order = &ThermosConstraint{}
t.addConstraint(t.order)
}
t.order.Order = order
return t
}
// Add Process to execution order. By default this is a FIFO setup. Custom order can be given by overriding
// with ProcessOrder
func (t *ThermosExecutor) addToOrder(name string) {
if t.order == nil {
t.order = &ThermosConstraint{Order: make([]string, 0)}
t.addConstraint(t.order)
}
t.order.Order = append(t.order.Order, name)
}
// Ram is determined by the job object.
func (t *ThermosExecutor) ram(ram int64) {
// Convert from bytes to MiB
ram *= 1024 ^ 2
t.Task.Resources.RAM = &ram
}
// Disk is determined by the job object.
func (t *ThermosExecutor) disk(disk int64) {
// Convert from bytes to MiB
disk *= 1024 ^ 2
t.Task.Resources.Disk = &disk
}
// CPU is determined by the job object.
func (t *ThermosExecutor) cpu(cpu float64) {
t.Task.Resources.CPU = &cpu
}
// GPU is determined by the job object.
func (t *ThermosExecutor) gpu(gpu int64) {
t.Task.Resources.GPU = &gpu
}
// Deep copy of Thermos executor
func (t *ThermosExecutor) Clone() *ThermosExecutor {
tNew := ThermosExecutor{}
if t.order != nil {
tNew.order = &ThermosConstraint{Order: t.order.Order}
tNew.addConstraint(tNew.order)
}
tNew.Task.Processes = make(map[string]*ThermosProcess)
for name, process := range t.Task.Processes {
newProcess := *process
tNew.Task.Processes[name] = &newProcess
}
tNew.Task.Resources = t.Task.Resources
return &tNew
}
type thermosTaskJSON struct {
Processes []*ThermosProcess `json:"processes"`
Constraints []*ThermosConstraint `json:"constraints"`
Resources thermosResources `json:"resources"`
}
// Custom Marshaling for Thermos Task to match what Thermos expects
func (t *ThermosTask) MarshalJSON() ([]byte, error) {
// Convert map to array to match what Thermos expects
processes := make([]*ThermosProcess, 0)
for _, process := range t.Processes {
processes = append(processes, process)
}
return json.Marshal(&thermosTaskJSON{
Processes: processes,
Constraints: t.Constraints,
Resources: t.Resources,
})
}
// Custom Unmarshaling to match what Thermos would contain
func (t *ThermosTask) UnmarshalJSON(data []byte) error {
// Thermos format
aux := &thermosTaskJSON{}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
processes := make(map[string]*ThermosProcess)
for _, process := range aux.Processes {
processes[process.Name] = process
}
return nil
}

View file

@ -1,71 +0,0 @@
package realis
import (
"encoding/json"
"testing"
"github.com/apache/thrift/lib/go/thrift"
"github.com/stretchr/testify/assert"
)
func TestThermosTask(t *testing.T) {
// Test that we can successfully deserialize a minimum subset of an Aurora generated thermos payload
thermosJSON := []byte(
`{
"task": {
"processes": [
{
"daemon": false,
"name": "hello",
"ephemeral": false,
"max_failures": 1,
"min_duration": 5,
"cmdline": "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
"final": false
}
],
"resources": {
"gpu": 0,
"disk": 134217728,
"ram": 134217728,
"cpu": 1.1
},
"constraints": [
{
"order": [
"hello"
]
}
]
}
}`)
thermos := ThermosExecutor{}
err := json.Unmarshal(thermosJSON, &thermos)
assert.NoError(t, err)
process := &ThermosProcess{
Daemon: false,
Name: "hello",
Ephemeral: false,
MaxFailures: 1,
MinDuration: 5,
Cmdline: "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
Final: false,
}
constraint := &ThermosConstraint{Order: []string{process.Name}}
thermosExpected := ThermosExecutor{
Task: ThermosTask{
Processes: map[string]*ThermosProcess{process.Name: process},
Constraints: []*ThermosConstraint{constraint},
Resources: thermosResources{CPU: thrift.Float64Ptr(1.1),
Disk: thrift.Int64Ptr(134217728),
RAM: thrift.Int64Ptr(134217728),
GPU: thrift.Int64Ptr(0)}}}
assert.ObjectsAreEqualValues(thermosExpected, thermos)
}

188
updatejob.go Normal file
View file

@ -0,0 +1,188 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"github.com/paypal/gorealis/gen-go/apache/aurora"
)
// UpdateJob is a structure to collect all information required to create job update.
type UpdateJob struct {
Job // SetInstanceCount for job is hidden, access via full qualifier
req *aurora.JobUpdateRequest
}
// NewDefaultUpdateJob creates an UpdateJob object with opinionated default settings.
func NewDefaultUpdateJob(config *aurora.TaskConfig) *UpdateJob {
req := aurora.NewJobUpdateRequest()
req.TaskConfig = config
req.Settings = NewUpdateSettings()
job, ok := NewJob().(*AuroraJob)
if !ok {
// This should never happen but it is here as a safeguard
return nil
}
job.jobConfig.TaskConfig = config
// Rebuild resource map from TaskConfig
for _, ptr := range config.Resources {
if ptr.NumCpus != nil {
job.resources[CPU].NumCpus = ptr.NumCpus
continue // Guard against Union violations that Go won't enforce
}
if ptr.RamMb != nil {
job.resources[RAM].RamMb = ptr.RamMb
continue
}
if ptr.DiskMb != nil {
job.resources[DISK].DiskMb = ptr.DiskMb
continue
}
if ptr.NumGpus != nil {
job.resources[GPU] = &aurora.Resource{NumGpus: ptr.NumGpus}
continue
}
}
// Mirrors defaults set by Pystachio
req.Settings.UpdateGroupSize = 1
req.Settings.WaitForBatchCompletion = false
req.Settings.MinWaitInInstanceRunningMs = 45000
req.Settings.MaxPerInstanceFailures = 0
req.Settings.MaxFailedInstances = 0
req.Settings.RollbackOnFailure = true
//TODO(rdelvalle): Deep copy job struct to avoid unexpected behavior
return &UpdateJob{Job: job, req: req}
}
// NewUpdateJob creates an UpdateJob object wihtout default settings.
func NewUpdateJob(config *aurora.TaskConfig, settings *aurora.JobUpdateSettings) *UpdateJob {
req := aurora.NewJobUpdateRequest()
req.TaskConfig = config
req.Settings = settings
job, ok := NewJob().(*AuroraJob)
if !ok {
// This should never happen but it is here as a safeguard
return nil
}
job.jobConfig.TaskConfig = config
// Rebuild resource map from TaskConfig
for _, ptr := range config.Resources {
if ptr.NumCpus != nil {
job.resources[CPU].NumCpus = ptr.NumCpus
continue // Guard against Union violations that Go won't enforce
}
if ptr.RamMb != nil {
job.resources[RAM].RamMb = ptr.RamMb
continue
}
if ptr.DiskMb != nil {
job.resources[DISK].DiskMb = ptr.DiskMb
continue
}
if ptr.NumGpus != nil {
job.resources[GPU] = &aurora.Resource{}
job.resources[GPU].NumGpus = ptr.NumGpus
continue // Guard against Union violations that Go won't enforce
}
}
//TODO(rdelvalle): Deep copy job struct to avoid unexpected behavior
return &UpdateJob{Job: job, req: req}
}
// InstanceCount sets instance count the job will have after the update.
func (u *UpdateJob) InstanceCount(inst int32) *UpdateJob {
u.req.InstanceCount = inst
return u
}
// BatchSize sets the max number of instances being updated at any given moment.
func (u *UpdateJob) BatchSize(size int32) *UpdateJob {
u.req.Settings.UpdateGroupSize = size
return u
}
// WatchTime sets the minimum number of seconds a shard must remain in RUNNING state before considered a success.
func (u *UpdateJob) WatchTime(ms int32) *UpdateJob {
u.req.Settings.MinWaitInInstanceRunningMs = ms
return u
}
// WaitForBatchCompletion configures the job update to wait for all instances in a group to be done before moving on.
func (u *UpdateJob) WaitForBatchCompletion(batchWait bool) *UpdateJob {
u.req.Settings.WaitForBatchCompletion = batchWait
return u
}
// MaxPerInstanceFailures sets the max number of instance failures to tolerate before marking instance as FAILED.
func (u *UpdateJob) MaxPerInstanceFailures(inst int32) *UpdateJob {
u.req.Settings.MaxPerInstanceFailures = inst
return u
}
// MaxFailedInstances sets the max number of FAILED instances to tolerate before terminating the update.
func (u *UpdateJob) MaxFailedInstances(inst int32) *UpdateJob {
u.req.Settings.MaxFailedInstances = inst
return u
}
// RollbackOnFail configure the job to rollback automatically after a job update fails.
func (u *UpdateJob) RollbackOnFail(rollback bool) *UpdateJob {
u.req.Settings.RollbackOnFailure = rollback
return u
}
// NewUpdateSettings return an opinionated set of job update settings.
func (u *UpdateJob) BatchUpdateStrategy(strategy aurora.BatchJobUpdateStrategy) *UpdateJob {
u.req.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{BatchStrategy: &strategy}
return u
}
func (u *UpdateJob) QueueUpdateStrategy(strategy aurora.QueueJobUpdateStrategy) *UpdateJob {
u.req.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{QueueStrategy: &strategy}
return u
}
func (u *UpdateJob) VariableBatchStrategy(strategy aurora.VariableBatchJobUpdateStrategy) *UpdateJob {
u.req.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{VarBatchStrategy: &strategy}
return u
}
func NewUpdateSettings() *aurora.JobUpdateSettings {
us := new(aurora.JobUpdateSettings)
// Mirrors defaults set by Pystachio
us.UpdateGroupSize = 1
us.WaitForBatchCompletion = false
us.MinWaitInInstanceRunningMs = 45000
us.MaxPerInstanceFailures = 0
us.MaxFailedInstances = 0
us.RollbackOnFailure = true
return us
}

70
util.go
View file

@ -4,15 +4,40 @@ import (
"net/url"
"strings"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/pkg/errors"
)
const apiPath = "/api"
// ActiveStates - States a task may be in when active.
var ActiveStates = make(map[aurora.ScheduleStatus]bool)
// SlaveAssignedStates - States a task may be in when it has already been assigned to a Mesos agent.
var SlaveAssignedStates = make(map[aurora.ScheduleStatus]bool)
// LiveStates - States a task may be in when it is live (e.g. able to take traffic)
var LiveStates = make(map[aurora.ScheduleStatus]bool)
// TerminalStates - Set of states a task may not transition away from.
var TerminalStates = make(map[aurora.ScheduleStatus]bool)
// ActiveJobUpdateStates - States a Job Update may be in where it is considered active.
var ActiveJobUpdateStates = make(map[aurora.JobUpdateStatus]bool)
// TerminalJobUpdateStates returns a slice containing all the terminal states an update may end up in.
// This is a function in order to avoid having a slice that can be accidentally mutated.
func TerminalUpdateStates() []aurora.JobUpdateStatus {
return []aurora.JobUpdateStatus{
aurora.JobUpdateStatus_ROLLED_FORWARD,
aurora.JobUpdateStatus_ROLLED_BACK,
aurora.JobUpdateStatus_ABORTED,
aurora.JobUpdateStatus_ERROR,
aurora.JobUpdateStatus_FAILED,
}
}
// AwaitingPulseJobUpdateStates - States a job update may be in where it is waiting for a pulse.
var AwaitingPulseJobUpdateStates = make(map[aurora.JobUpdateStatus]bool)
func init() {
@ -40,26 +65,14 @@ func init() {
}
}
// TerminalUpdateStates returns a slice containing all the terminal states an update may be in.
// This is a function in order to avoid having a slice that can be accidentally mutated.
func TerminalUpdateStates() []aurora.JobUpdateStatus {
return []aurora.JobUpdateStatus{
aurora.JobUpdateStatus_ROLLED_FORWARD,
aurora.JobUpdateStatus_ROLLED_BACK,
aurora.JobUpdateStatus_ABORTED,
aurora.JobUpdateStatus_ERROR,
aurora.JobUpdateStatus_FAILED,
}
}
func validateAuroraAddress(address string) (string, error) {
func validateAuroraURL(location string) (string, error) {
// If no protocol defined, assume http
if !strings.Contains(address, "://") {
address = "http://" + address
if !strings.Contains(location, "://") {
location = "http://" + location
}
u, err := url.Parse(address)
u, err := url.Parse(location)
if err != nil {
return "", errors.Wrap(err, "error parsing url")
@ -79,7 +92,8 @@ func validateAuroraAddress(address string) (string, error) {
return "", errors.Errorf("only protocols http and https are supported %v\n", u.Scheme)
}
if u.Path != "/api" {
// This could theoretically be elsewhwere but we'll be strict for the sake of simplicty
if u.Path != apiPath {
return "", errors.Errorf("expected /api path %v\n", u.Path)
}
@ -104,23 +118,3 @@ func calculateCurrentBatch(updatingInstances int32, batchSizes []int32) int {
}
return batchCount
}
func ResourcesToMap(resources []*aurora.Resource) map[string]float64 {
result := map[string]float64{}
for _, resource := range resources {
if resource.NumCpus != nil {
result["cpus"] += *resource.NumCpus
} else if resource.RamMb != nil {
result["mem"] += float64(*resource.RamMb)
} else if resource.DiskMb != nil {
result["disk"] += float64(*resource.DiskMb)
} else if resource.NamedPort != nil {
result["ports"]++
} else if resource.NumGpus != nil {
result["gpus"] += float64(*resource.NumGpus)
}
}
return result
}

View file

@ -20,6 +20,50 @@ import (
"github.com/stretchr/testify/assert"
)
func TestAuroraURLValidator(t *testing.T) {
t.Run("badURL", func(t *testing.T) {
url, err := validateAuroraURL("http://badurl.com/badpath")
assert.Empty(t, url)
assert.Error(t, err)
})
t.Run("URLHttp", func(t *testing.T) {
url, err := validateAuroraURL("http://goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLHttps", func(t *testing.T) {
url, err := validateAuroraURL("https://goodurl.com:8081/api")
assert.Equal(t, "https://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoPath", func(t *testing.T) {
url, err := validateAuroraURL("http://goodurl.com:8081")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("ipAddrNoPath", func(t *testing.T) {
url, err := validateAuroraURL("http://192.168.1.33:8081")
assert.Equal(t, "http://192.168.1.33:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocol", func(t *testing.T) {
url, err := validateAuroraURL("goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocolNoPathNoPort", func(t *testing.T) {
url, err := validateAuroraURL("goodurl.com")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
}
func TestCurrentBatchCalculator(t *testing.T) {
t.Run("singleBatchOverflow", func(t *testing.T) {
curBatch := calculateCurrentBatch(10, []int32{2})

View file

@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
const (
UNKNOWN_APPLICATION_EXCEPTION = 0
UNKNOWN_METHOD = 1
INVALID_MESSAGE_TYPE_EXCEPTION = 2
WRONG_METHOD_NAME = 3
BAD_SEQUENCE_ID = 4
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
)
var defaultApplicationExceptionMessage = map[int32]string{
UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception",
UNKNOWN_METHOD: "unknown method",
INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type",
WRONG_METHOD_NAME: "wrong method name",
BAD_SEQUENCE_ID: "bad sequence ID",
MISSING_RESULT: "missing result",
INTERNAL_ERROR: "unknown internal error",
PROTOCOL_ERROR: "unknown protocol error",
}
// Application level Thrift exception
type TApplicationException interface {
TException
TypeId() int32
Read(iprot TProtocol) error
Write(oprot TProtocol) error
}
type tApplicationException struct {
message string
type_ int32
}
func (e tApplicationException) Error() string {
if e.message != "" {
return e.message
}
return defaultApplicationExceptionMessage[e.type_]
}
func NewTApplicationException(type_ int32, message string) TApplicationException {
return &tApplicationException{message, type_}
}
func (p *tApplicationException) TypeId() int32 {
return p.type_
}
func (p *tApplicationException) Read(iprot TProtocol) error {
// TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
if err != nil {
return err
}
message := ""
type_ := int32(UNKNOWN_APPLICATION_EXCEPTION)
for {
_, ttype, id, err := iprot.ReadFieldBegin()
if err != nil {
return err
}
if ttype == STOP {
break
}
switch id {
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
return err
}
p.message = message
p.type_ = type_
return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {
err = oprot.WriteStructBegin("TApplicationException")
if len(p.Error()) > 0 {
err = oprot.WriteFieldBegin("message", STRING, 1)
if err != nil {
return
}
err = oprot.WriteString(p.Error())
if err != nil {
return
}
err = oprot.WriteFieldEnd()
if err != nil {
return
}
}
err = oprot.WriteFieldBegin("type", I32, 2)
if err != nil {
return
}
err = oprot.WriteI32(p.type_)
if err != nil {
return
}
err = oprot.WriteFieldEnd()
if err != nil {
return
}
err = oprot.WriteFieldStop()
if err != nil {
return
}
err = oprot.WriteStructEnd()
return
}

View file

@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestTApplicationException(t *testing.T) {
exc := NewTApplicationException(UNKNOWN_APPLICATION_EXCEPTION, "")
if exc.Error() != defaultApplicationExceptionMessage[UNKNOWN_APPLICATION_EXCEPTION] {
t.Fatalf("Expected empty string for exception but found '%s'", exc.Error())
}
if exc.TypeId() != UNKNOWN_APPLICATION_EXCEPTION {
t.Fatalf("Expected type UNKNOWN for exception but found '%v'", exc.TypeId())
}
exc = NewTApplicationException(WRONG_METHOD_NAME, "junk_method")
if exc.Error() != "junk_method" {
t.Fatalf("Expected 'junk_method' for exception but found '%s'", exc.Error())
}
if exc.TypeId() != WRONG_METHOD_NAME {
t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%v'", exc.TypeId())
}
}

View file

@ -0,0 +1,509 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
)
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
reader io.Reader
writer io.Writer
strictRead bool
strictWrite bool
buffer [64]byte
}
type TBinaryProtocolFactory struct {
strictRead bool
strictWrite bool
}
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
return NewTBinaryProtocol(t, false, true)
}
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
if et, ok := t.(TRichTransport); ok {
p.trans = et
} else {
p.trans = NewTRichTransport(t)
}
p.reader = p.trans
p.writer = p.trans
return p
}
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
return NewTBinaryProtocolFactory(false, true)
}
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
}
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
}
/**
* Writing Methods
*/
func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
if p.strictWrite {
version := uint32(VERSION_1) | uint32(typeId)
e := p.WriteI32(int32(version))
if e != nil {
return e
}
e = p.WriteString(name)
if e != nil {
return e
}
e = p.WriteI32(seqId)
return e
} else {
e := p.WriteString(name)
if e != nil {
return e
}
e = p.WriteByte(int8(typeId))
if e != nil {
return e
}
e = p.WriteI32(seqId)
return e
}
return nil
}
func (p *TBinaryProtocol) WriteMessageEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteStructBegin(name string) error {
return nil
}
func (p *TBinaryProtocol) WriteStructEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
e := p.WriteByte(int8(typeId))
if e != nil {
return e
}
e = p.WriteI16(id)
return e
}
func (p *TBinaryProtocol) WriteFieldEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteFieldStop() error {
e := p.WriteByte(STOP)
return e
}
func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
e := p.WriteByte(int8(keyType))
if e != nil {
return e
}
e = p.WriteByte(int8(valueType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteMapEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteListEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteSetEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteBool(value bool) error {
if value {
return p.WriteByte(1)
}
return p.WriteByte(0)
}
func (p *TBinaryProtocol) WriteByte(value int8) error {
e := p.trans.WriteByte(byte(value))
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI16(value int16) error {
v := p.buffer[0:2]
binary.BigEndian.PutUint16(v, uint16(value))
_, e := p.writer.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI32(value int32) error {
v := p.buffer[0:4]
binary.BigEndian.PutUint32(v, uint32(value))
_, e := p.writer.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI64(value int64) error {
v := p.buffer[0:8]
binary.BigEndian.PutUint64(v, uint64(value))
_, err := p.writer.Write(v)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteDouble(value float64) error {
return p.WriteI64(int64(math.Float64bits(value)))
}
func (p *TBinaryProtocol) WriteString(value string) error {
e := p.WriteI32(int32(len(value)))
if e != nil {
return e
}
_, err := p.trans.WriteString(value)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteBinary(value []byte) error {
e := p.WriteI32(int32(len(value)))
if e != nil {
return e
}
_, err := p.writer.Write(value)
return NewTProtocolException(err)
}
/**
* Reading methods
*/
func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
size, e := p.ReadI32()
if e != nil {
return "", typeId, 0, NewTProtocolException(e)
}
if size < 0 {
typeId = TMessageType(size & 0x0ff)
version := int64(int64(size) & VERSION_MASK)
if version != VERSION_1 {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
}
name, e = p.ReadString()
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
seqId, e = p.ReadI32()
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
return name, typeId, seqId, nil
}
if p.strictRead {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
}
name, e2 := p.readStringBody(size)
if e2 != nil {
return name, typeId, seqId, e2
}
b, e3 := p.ReadByte()
if e3 != nil {
return name, typeId, seqId, e3
}
typeId = TMessageType(b)
seqId, e4 := p.ReadI32()
if e4 != nil {
return name, typeId, seqId, e4
}
return name, typeId, seqId, nil
}
func (p *TBinaryProtocol) ReadMessageEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
return
}
func (p *TBinaryProtocol) ReadStructEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
t, err := p.ReadByte()
typeId = TType(t)
if err != nil {
return name, typeId, seqId, err
}
if t != STOP {
seqId, err = p.ReadI16()
}
return name, typeId, seqId, err
}
func (p *TBinaryProtocol) ReadFieldEnd() error {
return nil
}
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
k, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
kType = TType(k)
v, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
vType = TType(v)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return kType, vType, size, nil
}
func (p *TBinaryProtocol) ReadMapEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return
}
func (p *TBinaryProtocol) ReadListEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return elemType, size, nil
}
func (p *TBinaryProtocol) ReadSetEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadBool() (bool, error) {
b, e := p.ReadByte()
v := true
if b != 1 {
v = false
}
return v, e
}
func (p *TBinaryProtocol) ReadByte() (int8, error) {
v, err := p.trans.ReadByte()
return int8(v), err
}
func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
buf := p.buffer[0:2]
err = p.readAll(buf)
value = int16(binary.BigEndian.Uint16(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
buf := p.buffer[0:4]
err = p.readAll(buf)
value = int32(binary.BigEndian.Uint32(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
value = int64(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
value = math.Float64frombits(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadString() (value string, err error) {
size, e := p.ReadI32()
if e != nil {
return "", e
}
if size < 0 {
err = invalidDataLength
return
}
return p.readStringBody(size)
}
func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
size, e := p.ReadI32()
if e != nil {
return nil, e
}
if size < 0 {
return nil, invalidDataLength
}
isize := int(size)
buf := make([]byte, isize)
_, err := io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(err)
}
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TBinaryProtocol) Transport() TTransport {
return p.origTransport
}
func (p *TBinaryProtocol) readAll(buf []byte) error {
_, err := io.ReadFull(p.reader, buf)
return NewTProtocolException(err)
}
const readLimit = 32768
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
var (
buf bytes.Buffer
e error
b []byte
)
switch {
case int(size) <= len(p.buffer):
b = p.buffer[:size] // avoids allocation for small reads
case int(size) < readLimit:
b = make([]byte, size)
default:
b = make([]byte, readLimit)
}
for size > 0 {
_, e = io.ReadFull(p.trans, b)
buf.Write(b)
if e != nil {
break
}
size -= readLimit
if size < readLimit && size > 0 {
b = b[:size]
}
}
return buf.String(), NewTProtocolException(e)
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestReadWriteBinaryProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault())
}

View file

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"context"
)
type TBufferedTransportFactory struct {
size int
}
type TBufferedTransport struct {
bufio.ReadWriter
tp TTransport
}
func (p *TBufferedTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return NewTBufferedTransport(trans, p.size), nil
}
func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory {
return &TBufferedTransportFactory{size: bufferSize}
}
func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport {
return &TBufferedTransport{
ReadWriter: bufio.ReadWriter{
Reader: bufio.NewReaderSize(trans, bufferSize),
Writer: bufio.NewWriterSize(trans, bufferSize),
},
tp: trans,
}
}
func (p *TBufferedTransport) IsOpen() bool {
return p.tp.IsOpen()
}
func (p *TBufferedTransport) Open() (err error) {
return p.tp.Open()
}
func (p *TBufferedTransport) Close() (err error) {
return p.tp.Close()
}
func (p *TBufferedTransport) Read(b []byte) (int, error) {
n, err := p.ReadWriter.Read(b)
if err != nil {
p.ReadWriter.Reader.Reset(p.tp)
}
return n, err
}
func (p *TBufferedTransport) Write(b []byte) (int, error) {
n, err := p.ReadWriter.Write(b)
if err != nil {
p.ReadWriter.Writer.Reset(p.tp)
}
return n, err
}
func (p *TBufferedTransport) Flush(ctx context.Context) error {
if err := p.ReadWriter.Flush(); err != nil {
p.ReadWriter.Writer.Reset(p.tp)
return err
}
return p.tp.Flush(ctx)
}
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {
return p.tp.RemainingBytes()
}

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestBufferedTransport(t *testing.T) {
trans := NewTBufferedTransport(NewTMemoryBuffer(), 10240)
TransportTest(t, trans, trans)
}

View file

@ -0,0 +1,85 @@
package thrift
import (
"context"
"fmt"
)
type TClient interface {
Call(ctx context.Context, method string, args, result TStruct) error
}
type TStandardClient struct {
seqId int32
iprot, oprot TProtocol
}
// TStandardClient implements TClient, and uses the standard message format for Thrift.
// It is not safe for concurrent use.
func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
return &TStandardClient{
iprot: inputProtocol,
oprot: outputProtocol,
}
}
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
return err
}
if err := args.Write(oprot); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
return err
}
return oprot.Flush(ctx)
}
func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
if err != nil {
return err
}
if method != rMethod {
return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
} else if seqId != rSeqId {
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
if err := exception.Read(iprot); err != nil {
return err
}
if err := iprot.ReadMessageEnd(); err != nil {
return err
}
return &exception
} else if rTypeId != REPLY {
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
if err := result.Read(iprot); err != nil {
return err
}
return iprot.ReadMessageEnd()
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
p.seqId++
seqId := p.seqId
if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
return err
}
// method is oneway
if result == nil {
return nil
}
return p.Recv(p.iprot, seqId, method, result)
}

View file

@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
type mockProcessor struct {
ProcessFunc func(in, out TProtocol) (bool, TException)
}
func (m *mockProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
return m.ProcessFunc(in, out)
}

View file

@ -0,0 +1,810 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/binary"
"fmt"
"io"
"math"
)
const (
COMPACT_PROTOCOL_ID = 0x082
COMPACT_VERSION = 1
COMPACT_VERSION_MASK = 0x1f
COMPACT_TYPE_MASK = 0x0E0
COMPACT_TYPE_BITS = 0x07
COMPACT_TYPE_SHIFT_AMOUNT = 5
)
type tCompactType byte
const (
COMPACT_BOOLEAN_TRUE = 0x01
COMPACT_BOOLEAN_FALSE = 0x02
COMPACT_BYTE = 0x03
COMPACT_I16 = 0x04
COMPACT_I32 = 0x05
COMPACT_I64 = 0x06
COMPACT_DOUBLE = 0x07
COMPACT_BINARY = 0x08
COMPACT_LIST = 0x09
COMPACT_SET = 0x0A
COMPACT_MAP = 0x0B
COMPACT_STRUCT = 0x0C
)
var (
ttypeToCompactType map[TType]tCompactType
)
func init() {
ttypeToCompactType = map[TType]tCompactType{
STOP: STOP,
BOOL: COMPACT_BOOLEAN_TRUE,
BYTE: COMPACT_BYTE,
I16: COMPACT_I16,
I32: COMPACT_I32,
I64: COMPACT_I64,
DOUBLE: COMPACT_DOUBLE,
STRING: COMPACT_BINARY,
LIST: COMPACT_LIST,
SET: COMPACT_SET,
MAP: COMPACT_MAP,
STRUCT: COMPACT_STRUCT,
}
}
type TCompactProtocolFactory struct{}
func NewTCompactProtocolFactory() *TCompactProtocolFactory {
return &TCompactProtocolFactory{}
}
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTCompactProtocol(trans)
}
type TCompactProtocol struct {
trans TRichTransport
origTransport TTransport
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField []int
lastFieldId int
// If we encounter a boolean field begin, save the TField here so it can
// have the value incorporated.
booleanFieldName string
booleanFieldId int16
booleanFieldPending bool
// If we read a field header, and it's a boolean field, save the boolean
// value here so that readBool can use it.
boolValue bool
boolValueIsNotNull bool
buffer [64]byte
}
// Create a TCompactProtocol given a TTransport
func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
p := &TCompactProtocol{origTransport: trans, lastField: []int{}}
if et, ok := trans.(TRichTransport); ok {
p.trans = et
} else {
p.trans = NewTRichTransport(trans)
}
return p
}
//
// Public Writing methods.
//
// Write a message header to the wire. Compact Protocol messages contain the
// protocol version so we can migrate forwards in the future if need be.
func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := p.writeByteDirect(COMPACT_PROTOCOL_ID)
if err != nil {
return NewTProtocolException(err)
}
err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK))
if err != nil {
return NewTProtocolException(err)
}
_, err = p.writeVarint32(seqid)
if err != nil {
return NewTProtocolException(err)
}
e := p.WriteString(name)
return e
}
func (p *TCompactProtocol) WriteMessageEnd() error { return nil }
// Write a struct begin. This doesn't actually put anything on the wire. We
// use it as an opportunity to put special placeholder markers on the field
// stack so we can get the field id deltas correct.
func (p *TCompactProtocol) WriteStructBegin(name string) error {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return nil
}
// Write a struct end. This doesn't actually put anything on the wire. We use
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
func (p *TCompactProtocol) WriteStructEnd() error {
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if typeId == BOOL {
// we want to possibly include the value, so we'll wait.
p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true
return nil
}
_, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF)
return NewTProtocolException(err)
}
// The workhorse of writeFieldBegin. It has the option of doing a
// 'type override' of the type header. This is used specifically in the
// boolean field case.
func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) {
// short lastField = lastField_.pop();
// if there's a type override, use that.
var typeToWrite byte
if typeOverride == 0xFF {
typeToWrite = byte(p.getCompactType(typeId))
} else {
typeToWrite = typeOverride
}
// check if we can use delta encoding for the field id
fieldId := int(id)
written := 0
if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 {
// write them together
err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite)
if err != nil {
return 0, err
}
} else {
// write them separate
err := p.writeByteDirect(typeToWrite)
if err != nil {
return 0, err
}
err = p.WriteI16(id)
written = 1 + 2
if err != nil {
return 0, err
}
}
p.lastFieldId = fieldId
// p.lastField.Push(field.id);
return written, nil
}
func (p *TCompactProtocol) WriteFieldEnd() error { return nil }
func (p *TCompactProtocol) WriteFieldStop() error {
err := p.writeByteDirect(STOP)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
if size == 0 {
err := p.writeByteDirect(0)
return NewTProtocolException(err)
}
_, err := p.writeVarint32(int32(size))
if err != nil {
return NewTProtocolException(err)
}
err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType)))
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapEnd() error { return nil }
// Write a list header.
func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteListEnd() error { return nil }
// Write a set header.
func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteSetEnd() error { return nil }
func (p *TCompactProtocol) WriteBool(value bool) error {
v := byte(COMPACT_BOOLEAN_FALSE)
if value {
v = byte(COMPACT_BOOLEAN_TRUE)
}
if p.booleanFieldPending {
// we haven't written the field header yet
_, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v)
p.booleanFieldPending = false
return NewTProtocolException(err)
}
// we're not part of a field, so just write the value.
err := p.writeByteDirect(v)
return NewTProtocolException(err)
}
// Write a byte. Nothing to see here!
func (p *TCompactProtocol) WriteByte(value int8) error {
err := p.writeByteDirect(byte(value))
return NewTProtocolException(err)
}
// Write an I16 as a zigzag varint.
func (p *TCompactProtocol) WriteI16(value int16) error {
_, err := p.writeVarint32(p.int32ToZigzag(int32(value)))
return NewTProtocolException(err)
}
// Write an i32 as a zigzag varint.
func (p *TCompactProtocol) WriteI32(value int32) error {
_, err := p.writeVarint32(p.int32ToZigzag(value))
return NewTProtocolException(err)
}
// Write an i64 as a zigzag varint.
func (p *TCompactProtocol) WriteI64(value int64) error {
_, err := p.writeVarint64(p.int64ToZigzag(value))
return NewTProtocolException(err)
}
// Write a double to the wire as 8 bytes.
func (p *TCompactProtocol) WriteDouble(value float64) error {
buf := p.buffer[0:8]
binary.LittleEndian.PutUint64(buf, math.Float64bits(value))
_, err := p.trans.Write(buf)
return NewTProtocolException(err)
}
// Write a string to the wire with a varint size preceding.
func (p *TCompactProtocol) WriteString(value string) error {
_, e := p.writeVarint32(int32(len(value)))
if e != nil {
return NewTProtocolException(e)
}
if len(value) > 0 {
}
_, e = p.trans.WriteString(value)
return e
}
// Write a byte array, using a varint for the size.
func (p *TCompactProtocol) WriteBinary(bin []byte) error {
_, e := p.writeVarint32(int32(len(bin)))
if e != nil {
return NewTProtocolException(e)
}
if len(bin) > 0 {
_, e = p.trans.Write(bin)
return NewTProtocolException(e)
}
return nil
}
//
// Reading methods.
//
// Read a message header.
func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
protocolId, err := p.readByteDirect()
if err != nil {
return
}
if protocolId != COMPACT_PROTOCOL_ID {
e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId)
return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e)
}
versionAndType, err := p.readByteDirect()
if err != nil {
return
}
version := versionAndType & COMPACT_VERSION_MASK
typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS)
if version != COMPACT_VERSION {
e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version)
err = NewTProtocolExceptionWithType(BAD_VERSION, e)
return
}
seqId, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
name, err = p.ReadString()
return
}
func (p *TCompactProtocol) ReadMessageEnd() error { return nil }
// Read a struct begin. There's nothing on the wire for this, but it is our
// opportunity to push a new struct begin marker onto the field stack.
func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return
}
// Doesn't actually consume any wire data, just removes the last field for
// this struct from the field stack.
func (p *TCompactProtocol) ReadStructEnd() error {
// consume the last field we read off the wire.
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
// Read a field header off the wire.
func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
t, err := p.readByteDirect()
if err != nil {
return
}
// if it's a stop, then we can return immediately, as the struct is over.
if (t & 0x0f) == STOP {
return "", STOP, 0, nil
}
// mask off the 4 MSB of the type header. it could contain a field id delta.
modifier := int16((t & 0xf0) >> 4)
if modifier == 0 {
// not a delta. look ahead for the zigzag varint field id.
id, err = p.ReadI16()
if err != nil {
return
}
} else {
// has a delta. add the delta to the last read field id.
id = int16(p.lastFieldId) + modifier
}
typeId, e := p.getTType(tCompactType(t & 0x0f))
if e != nil {
err = NewTProtocolException(e)
return
}
// if this happens to be a boolean field, the value is encoded in the type
if p.isBoolType(t) {
// save the boolean value in a special instance variable.
p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE)
p.boolValueIsNotNull = true
}
// push the new field onto the field stack so we can keep the deltas going.
p.lastFieldId = int(id)
return
}
func (p *TCompactProtocol) ReadFieldEnd() error { return nil }
// Read a map header off the wire. If the size is zero, skip reading the key
// and value type. This means that 0-length maps will yield TMaps without the
// "correct" types.
func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
size32, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
keyAndValueType := byte(STOP)
if size != 0 {
keyAndValueType, err = p.readByteDirect()
if err != nil {
return
}
}
keyType, _ = p.getTType(tCompactType(keyAndValueType >> 4))
valueType, _ = p.getTType(tCompactType(keyAndValueType & 0xf))
return
}
func (p *TCompactProtocol) ReadMapEnd() error { return nil }
// Read a list header off the wire. If the list size is 0-14, the size will
// be packed into the element type header. If it's a longer list, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) {
size_and_type, err := p.readByteDirect()
if err != nil {
return
}
size = int((size_and_type >> 4) & 0x0f)
if size == 15 {
size2, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size2 < 0 {
err = invalidDataLength
return
}
size = int(size2)
}
elemType, e := p.getTType(tCompactType(size_and_type))
if e != nil {
err = NewTProtocolException(e)
return
}
return
}
func (p *TCompactProtocol) ReadListEnd() error { return nil }
// Read a set header off the wire. If the set size is 0-14, the size will
// be packed into the element type header. If it's a longer set, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return p.ReadListBegin()
}
func (p *TCompactProtocol) ReadSetEnd() error { return nil }
// Read a boolean off the wire. If this is a boolean field, the value should
// already have been read during readFieldBegin, so we'll just consume the
// pre-stored value. Otherwise, read a byte.
func (p *TCompactProtocol) ReadBool() (value bool, err error) {
if p.boolValueIsNotNull {
p.boolValueIsNotNull = false
return p.boolValue, nil
}
v, err := p.readByteDirect()
return v == COMPACT_BOOLEAN_TRUE, err
}
// Read a single byte off the wire. Nothing interesting here.
func (p *TCompactProtocol) ReadByte() (int8, error) {
v, err := p.readByteDirect()
if err != nil {
return 0, NewTProtocolException(err)
}
return int8(v), err
}
// Read an i16 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI16() (value int16, err error) {
v, err := p.ReadI32()
return int16(v), err
}
// Read an i32 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI32() (value int32, err error) {
v, e := p.readVarint32()
if e != nil {
return 0, NewTProtocolException(e)
}
value = p.zigzagToInt32(v)
return value, nil
}
// Read an i64 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI64() (value int64, err error) {
v, e := p.readVarint64()
if e != nil {
return 0, NewTProtocolException(e)
}
value = p.zigzagToInt64(v)
return value, nil
}
// No magic here - just read a double off the wire.
func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
longBits := p.buffer[0:8]
_, e := io.ReadFull(p.trans, longBits)
if e != nil {
return 0.0, NewTProtocolException(e)
}
return math.Float64frombits(p.bytesToUint64(longBits)), nil
}
// Reads a []byte (via readBinary), and then UTF-8 decodes it.
func (p *TCompactProtocol) ReadString() (value string, err error) {
length, e := p.readVarint32()
if e != nil {
return "", NewTProtocolException(e)
}
if length < 0 {
return "", invalidDataLength
}
if length == 0 {
return "", nil
}
var buf []byte
if length <= int32(len(p.buffer)) {
buf = p.buffer[0:length]
} else {
buf = make([]byte, length)
}
_, e = io.ReadFull(p.trans, buf)
return string(buf), NewTProtocolException(e)
}
// Read a []byte from the wire.
func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
length, e := p.readVarint32()
if e != nil {
return nil, NewTProtocolException(e)
}
if length == 0 {
return []byte{}, nil
}
if length < 0 {
return nil, invalidDataLength
}
buf := make([]byte, length)
_, e = io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(e)
}
func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TCompactProtocol) Transport() TTransport {
return p.origTransport
}
//
// Internal writing methods
//
// Abstract method for writing the start of lists and sets. List and sets on
// the wire differ only by the type indicator.
func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) {
if size <= 14 {
return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType))))
}
err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType)))
if err != nil {
return 0, err
}
m, err := p.writeVarint32(int32(size))
return 1 + m, err
}
// Write an i32 as a varint. Results in 1-5 bytes on the wire.
// TODO(pomack): make a permanent buffer like writeVarint64?
func (p *TCompactProtocol) writeVarint32(n int32) (int, error) {
i32buf := p.buffer[0:5]
idx := 0
for {
if (n & ^0x7F) == 0 {
i32buf[idx] = byte(n)
idx++
// p.writeByteDirect(byte(n));
break
// return;
} else {
i32buf[idx] = byte((n & 0x7F) | 0x80)
idx++
// p.writeByteDirect(byte(((n & 0x7F) | 0x80)));
u := uint32(n)
n = int32(u >> 7)
}
}
return p.trans.Write(i32buf[0:idx])
}
// Write an i64 as a varint. Results in 1-10 bytes on the wire.
func (p *TCompactProtocol) writeVarint64(n int64) (int, error) {
varint64out := p.buffer[0:10]
idx := 0
for {
if (n & ^0x7F) == 0 {
varint64out[idx] = byte(n)
idx++
break
} else {
varint64out[idx] = byte((n & 0x7F) | 0x80)
idx++
u := uint64(n)
n = int64(u >> 7)
}
}
return p.trans.Write(varint64out[0:idx])
}
// Convert l into a zigzag long. This allows negative numbers to be
// represented compactly as a varint.
func (p *TCompactProtocol) int64ToZigzag(l int64) int64 {
return (l << 1) ^ (l >> 63)
}
// Convert l into a zigzag long. This allows negative numbers to be
// represented compactly as a varint.
func (p *TCompactProtocol) int32ToZigzag(n int32) int32 {
return (n << 1) ^ (n >> 31)
}
func (p *TCompactProtocol) fixedUint64ToBytes(n uint64, buf []byte) {
binary.LittleEndian.PutUint64(buf, n)
}
func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) {
binary.LittleEndian.PutUint64(buf, uint64(n))
}
// Writes a byte without any possibility of all that field header nonsense.
// Used internally by other writing methods that know they need to write a byte.
func (p *TCompactProtocol) writeByteDirect(b byte) error {
return p.trans.WriteByte(b)
}
// Writes a byte without any possibility of all that field header nonsense.
func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) {
return 1, p.writeByteDirect(byte(n))
}
//
// Internal reading methods
//
// Read an i32 from the wire as a varint. The MSB of each byte is set
// if there is another byte to follow. This can read up to 5 bytes.
func (p *TCompactProtocol) readVarint32() (int32, error) {
// if the wire contains the right stuff, this will just truncate the i64 we
// read and get us the right sign.
v, err := p.readVarint64()
return int32(v), err
}
// Read an i64 from the wire as a proper varint. The MSB of each byte is set
// if there is another byte to follow. This can read up to 10 bytes.
func (p *TCompactProtocol) readVarint64() (int64, error) {
shift := uint(0)
result := int64(0)
for {
b, err := p.readByteDirect()
if err != nil {
return 0, err
}
result |= int64(b&0x7f) << shift
if (b & 0x80) != 0x80 {
break
}
shift += 7
}
return result, nil
}
// Read a byte, unlike ReadByte that reads Thrift-byte that is i8.
func (p *TCompactProtocol) readByteDirect() (byte, error) {
return p.trans.ReadByte()
}
//
// encoding helpers
//
// Convert from zigzag int to int.
func (p *TCompactProtocol) zigzagToInt32(n int32) int32 {
u := uint32(n)
return int32(u>>1) ^ -(n & 1)
}
// Convert from zigzag long to long.
func (p *TCompactProtocol) zigzagToInt64(n int64) int64 {
u := uint64(n)
return int64(u>>1) ^ -(n & 1)
}
// Note that it's important that the mask bytes are long literals,
// otherwise they'll default to ints, and when you shift an int left 56 bits,
// you just get a messed up int.
func (p *TCompactProtocol) bytesToInt64(b []byte) int64 {
return int64(binary.LittleEndian.Uint64(b))
}
// Note that it's important that the mask bytes are long literals,
// otherwise they'll default to ints, and when you shift an int left 56 bits,
// you just get a messed up int.
func (p *TCompactProtocol) bytesToUint64(b []byte) uint64 {
return binary.LittleEndian.Uint64(b)
}
//
// type testing and converting
//
func (p *TCompactProtocol) isBoolType(b byte) bool {
return (b&0x0f) == COMPACT_BOOLEAN_TRUE || (b&0x0f) == COMPACT_BOOLEAN_FALSE
}
// Given a tCompactType constant, convert it to its corresponding
// TType value.
func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
switch byte(t) & 0x0f {
case STOP:
return STOP, nil
case COMPACT_BOOLEAN_FALSE, COMPACT_BOOLEAN_TRUE:
return BOOL, nil
case COMPACT_BYTE:
return BYTE, nil
case COMPACT_I16:
return I16, nil
case COMPACT_I32:
return I32, nil
case COMPACT_I64:
return I64, nil
case COMPACT_DOUBLE:
return DOUBLE, nil
case COMPACT_BINARY:
return STRING, nil
case COMPACT_LIST:
return LIST, nil
case COMPACT_SET:
return SET, nil
case COMPACT_MAP:
return MAP, nil
case COMPACT_STRUCT:
return STRUCT, nil
}
return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.
func (p *TCompactProtocol) getCompactType(t TType) tCompactType {
return ttypeToCompactType[t]
}

View file

@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"testing"
)
func TestReadWriteCompactProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTCompactProtocolFactory())
transports := []TTransport{
NewTMemoryBuffer(),
NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))),
NewTFramedTransport(NewTMemoryBuffer()),
}
zlib0, _ := NewTZlibTransport(NewTMemoryBuffer(), 0)
zlib6, _ := NewTZlibTransport(NewTMemoryBuffer(), 6)
zlib9, _ := NewTZlibTransport(NewTFramedTransport(NewTMemoryBuffer()), 9)
transports = append(transports, zlib0, zlib6, zlib9)
for _, trans := range transports {
p := NewTCompactProtocol(trans)
ReadWriteBool(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteByte(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteI16(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteI32(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteI64(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteDouble(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteString(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteBinary(t, p, trans)
trans.Close()
}
}

View file

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
var defaultCtx = context.Background()

View file

@ -0,0 +1,270 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"log"
)
type TDebugProtocol struct {
Delegate TProtocol
LogPrefix string
}
type TDebugProtocolFactory struct {
Underlying TProtocolFactory
LogPrefix string
}
func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory {
return &TDebugProtocolFactory{
Underlying: underlying,
LogPrefix: logPrefix,
}
}
func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return &TDebugProtocol{
Delegate: t.Underlying.GetProtocol(trans),
LogPrefix: t.LogPrefix,
}
}
func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid)
log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
return err
}
func (tdp *TDebugProtocol) WriteMessageEnd() error {
err := tdp.Delegate.WriteMessageEnd()
log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteStructBegin(name string) error {
err := tdp.Delegate.WriteStructBegin(name)
log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
return err
}
func (tdp *TDebugProtocol) WriteStructEnd() error {
err := tdp.Delegate.WriteStructEnd()
log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
err := tdp.Delegate.WriteFieldBegin(name, typeId, id)
log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldEnd() error {
err := tdp.Delegate.WriteFieldEnd()
log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldStop() error {
err := tdp.Delegate.WriteFieldStop()
log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
err := tdp.Delegate.WriteMapBegin(keyType, valueType, size)
log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteMapEnd() error {
err := tdp.Delegate.WriteMapEnd()
log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteListBegin(elemType, size)
log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteListEnd() error {
err := tdp.Delegate.WriteListEnd()
log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteSetBegin(elemType, size)
log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteSetEnd() error {
err := tdp.Delegate.WriteSetEnd()
log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteBool(value bool) error {
err := tdp.Delegate.WriteBool(value)
log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteByte(value int8) error {
err := tdp.Delegate.WriteByte(value)
log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI16(value int16) error {
err := tdp.Delegate.WriteI16(value)
log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI32(value int32) error {
err := tdp.Delegate.WriteI32(value)
log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI64(value int64) error {
err := tdp.Delegate.WriteI64(value)
log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteDouble(value float64) error {
err := tdp.Delegate.WriteDouble(value)
log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteString(value string) error {
err := tdp.Delegate.WriteString(value)
log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
err := tdp.Delegate.WriteBinary(value)
log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
return
}
func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
err = tdp.Delegate.ReadMessageEnd()
log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
name, err = tdp.Delegate.ReadStructBegin()
log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
return
}
func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
err = tdp.Delegate.ReadStructEnd()
log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
return
}
func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
err = tdp.Delegate.ReadFieldEnd()
log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
return
}
func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
err = tdp.Delegate.ReadMapEnd()
log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadListBegin()
log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadListEnd() (err error) {
err = tdp.Delegate.ReadListEnd()
log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadSetBegin()
log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
err = tdp.Delegate.ReadSetEnd()
log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
value, err = tdp.Delegate.ReadBool()
log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
value, err = tdp.Delegate.ReadByte()
log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
value, err = tdp.Delegate.ReadI16()
log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
value, err = tdp.Delegate.ReadI32()
log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
value, err = tdp.Delegate.ReadI64()
log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
value, err = tdp.Delegate.ReadDouble()
log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadString() (value string, err error) {
value, err = tdp.Delegate.ReadString()
log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
value, err = tdp.Delegate.ReadBinary()
log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
err = tdp.Delegate.Skip(fieldType)
log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) Transport() TTransport {
return tdp.Delegate.Transport()
}

View file

@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
type TDeserializer struct {
Transport TTransport
Protocol TProtocol
}
func NewTDeserializer() *TDeserializer {
var transport TTransport
transport = NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
return &TDeserializer{
transport,
protocol}
}
func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
err = nil
if _, err = t.Transport.Write([]byte(s)); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
return
}
return
}
func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
err = nil
if _, err = t.Transport.Write(b); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
return
}
return
}

View file

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
)
// Generic Thrift exception
type TException interface {
error
}
// Prepends additional information to an error without losing the Thrift exception interface
func PrependError(prepend string, err error) error {
if t, ok := err.(TTransportException); ok {
return NewTTransportException(t.TypeId(), prepend+t.Error())
}
if t, ok := err.(TProtocolException); ok {
return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error()))
}
if t, ok := err.(TApplicationException); ok {
return NewTApplicationException(t.TypeId(), prepend+t.Error())
}
return errors.New(prepend + err.Error())
}

View file

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
"testing"
)
func TestPrependError(t *testing.T) {
err := NewTApplicationException(INTERNAL_ERROR, "original error")
err2, ok := PrependError("Prepend: ", err).(TApplicationException)
if !ok {
t.Fatal("Couldn't cast error TApplicationException")
}
if err2.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err2.TypeId() != INTERNAL_ERROR {
t.Fatal("Unexpected type error")
}
err3 := NewTProtocolExceptionWithType(INVALID_DATA, errors.New("original error"))
err4, ok := PrependError("Prepend: ", err3).(TProtocolException)
if !ok {
t.Fatal("Couldn't cast error TProtocolException")
}
if err4.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err4.TypeId() != INVALID_DATA {
t.Fatal("Unexpected type error")
}
err5 := NewTTransportException(TIMED_OUT, "original error")
err6, ok := PrependError("Prepend: ", err5).(TTransportException)
if !ok {
t.Fatal("Couldn't cast error TTransportException")
}
if err6.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err6.TypeId() != TIMED_OUT {
t.Fatal("Unexpected type error")
}
err7 := errors.New("original error")
err8 := PrependError("Prepend: ", err7)
if err8.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
}

79
vendor/github.com/apache/thrift/lib/go/thrift/field.go generated vendored Normal file
View file

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Helper class that encapsulates field metadata.
type field struct {
name string
typeId TType
id int
}
func newField(n string, t TType, i int) *field {
return &field{name: n, typeId: t, id: i}
}
func (p *field) Name() string {
if p == nil {
return ""
}
return p.name
}
func (p *field) TypeId() TType {
if p == nil {
return TType(VOID)
}
return p.typeId
}
func (p *field) Id() int {
if p == nil {
return -1
}
return p.id
}
func (p *field) String() string {
if p == nil {
return "<nil>"
}
return "<TField name:'" + p.name + "' type:" + string(p.typeId) + " field-id:" + string(p.id) + ">"
}
var ANONYMOUS_FIELD *field
type fieldSlice []field
func (p fieldSlice) Len() int {
return len(p)
}
func (p fieldSlice) Less(i, j int) bool {
return p[i].Id() < p[j].Id()
}
func (p fieldSlice) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
func init() {
ANONYMOUS_FIELD = newField("", STOP, 0)
}

View file

@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
)
const DEFAULT_MAX_LENGTH = 16384000
type TFramedTransport struct {
transport TTransport
buf bytes.Buffer
reader *bufio.Reader
frameSize uint32 //Current remaining size of the frame. if ==0 read next frame header
buffer [4]byte
maxLength uint32
}
type tFramedTransportFactory struct {
factory TTransportFactory
maxLength uint32
}
func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH}
}
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
return NewTFramedTransportMaxLength(tt, p.maxLength), nil
}
func NewTFramedTransport(transport TTransport) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH}
}
func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength}
}
func (p *TFramedTransport) Open() error {
return p.transport.Open()
}
func (p *TFramedTransport) IsOpen() bool {
return p.transport.IsOpen()
}
func (p *TFramedTransport) Close() error {
return p.transport.Close()
}
func (p *TFramedTransport) Read(buf []byte) (l int, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < uint32(len(buf)) {
frameSize := p.frameSize
tmp := make([]byte, p.frameSize)
l, err = p.Read(tmp)
copy(buf, tmp)
if err == nil {
err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf)))
return
}
}
got, err := p.reader.Read(buf)
p.frameSize = p.frameSize - uint32(got)
//sanity check
if p.frameSize < 0 {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size")
}
return got, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) ReadByte() (c byte, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < 1 {
return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", p.frameSize, 1))
}
c, err = p.reader.ReadByte()
if err == nil {
p.frameSize--
}
return
}
func (p *TFramedTransport) Write(buf []byte) (int, error) {
n, err := p.buf.Write(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) WriteByte(c byte) error {
return p.buf.WriteByte(c)
}
func (p *TFramedTransport) WriteString(s string) (n int, err error) {
return p.buf.WriteString(s)
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
size := p.buf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
if size > 0 {
if n, err := p.buf.WriteTo(p.transport); err != nil {
print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n")
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
}
err = p.transport.Flush(ctx)
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) readFrameHeader() (uint32, error) {
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
return 0, err
}
size := binary.BigEndian.Uint32(buf)
if size < 0 || size > p.maxLength {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
return size, nil
}
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.frameSize)
}

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestFramedTransport(t *testing.T) {
trans := NewTFramedTransport(NewTMemoryBuffer())
TransportTest(t, trans, trans)
}

View file

@ -0,0 +1,242 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
)
// Default to using the shared http client. Library users are
// free to change this global client or specify one through
// THttpClientOptions.
var DefaultHttpClient *http.Client = http.DefaultClient
type THttpClient struct {
client *http.Client
response *http.Response
url *url.URL
requestBuffer *bytes.Buffer
header http.Header
nsecConnectTimeout int64
nsecReadTimeout int64
}
type THttpClientTransportFactory struct {
options THttpClientOptions
url string
}
func (p *THttpClientTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*THttpClient)
if ok && t.url != nil {
return NewTHttpClientWithOptions(t.url.String(), p.options)
}
}
return NewTHttpClientWithOptions(p.url, p.options)
}
type THttpClientOptions struct {
// If nil, DefaultHttpClient is used
Client *http.Client
}
func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return &THttpClientTransportFactory{url: url, options: options}
}
func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
}
buf := make([]byte, 0, 1024)
client := options.Client
if client == nil {
client = DefaultHttpClient
}
httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}}
return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil
}
func NewTHttpClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}
// Set the HTTP Header for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// httpTrans.SetHeader("User-Agent","Thrift Client 1.0")
func (p *THttpClient) SetHeader(key string, value string) {
p.header.Add(key, value)
}
// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// hdrValue := httpTrans.GetHeader("User-Agent")
func (p *THttpClient) GetHeader(key string) string {
return p.header.Get(key)
}
// Deletes the HTTP Header given a Header Key for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// httpTrans.DelHeader("User-Agent")
func (p *THttpClient) DelHeader(key string) {
p.header.Del(key)
}
func (p *THttpClient) Open() error {
// do nothing
return nil
}
func (p *THttpClient) IsOpen() bool {
return p.response != nil || p.requestBuffer != nil
}
func (p *THttpClient) closeResponse() error {
var err error
if p.response != nil && p.response.Body != nil {
// The docs specify that if keepalive is enabled and the response body is not
// read to completion the connection will never be returned to the pool and
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
io.Copy(ioutil.Discard, p.response.Body)
err = p.response.Body.Close()
}
p.response = nil
return err
}
func (p *THttpClient) Close() error {
if p.requestBuffer != nil {
p.requestBuffer.Reset()
p.requestBuffer = nil
}
return p.closeResponse()
}
func (p *THttpClient) Read(buf []byte) (int, error) {
if p.response == nil {
return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.")
}
n, err := p.response.Body.Read(buf)
if n > 0 && (err == nil || err == io.EOF) {
return n, nil
}
return n, NewTTransportExceptionFromError(err)
}
func (p *THttpClient) ReadByte() (c byte, err error) {
return readByte(p.response.Body)
}
func (p *THttpClient) Write(buf []byte) (int, error) {
n, err := p.requestBuffer.Write(buf)
return n, err
}
func (p *THttpClient) WriteByte(c byte) error {
return p.requestBuffer.WriteByte(c)
}
func (p *THttpClient) WriteString(s string) (n int, err error) {
return p.requestBuffer.WriteString(s)
}
func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
if err != nil {
return NewTTransportExceptionFromError(err)
}
req.Header = p.header
if ctx != nil {
req = req.WithContext(ctx)
}
response, err := p.client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)
}
if response.StatusCode != http.StatusOK {
// Close the response to avoid leaking file descriptors. closeResponse does
// more than just call Close(), so temporarily assign it and reuse the logic.
p.response = response
p.closeResponse()
// TODO(pomack) log bad response
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode))
}
p.response = response
return nil
}
func (p *THttpClient) RemainingBytes() (num_bytes uint64) {
len := p.response.ContentLength
if len >= 0 {
return uint64(len)
}
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
}
// Deprecated: Use NewTHttpClientTransportFactory instead.
func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
// Deprecated: Use NewTHttpClientTransportFactoryWithOptions instead.
func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, options)
}
// Deprecated: Use NewTHttpClientWithOptions instead.
func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, options)
}
// Deprecated: Use NewTHttpClient instead.
func NewTHttpPostClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}

View file

@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"net/http"
"testing"
)
func TestHttpClient(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportTest(t, trans, trans)
}
func TestHttpClientHeaders(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
}
func TestHttpCustomClient(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
httpTransport := &customHttpTransport{}
trans, err := NewTHttpPostClientWithOptions("http://"+addr.String(), THttpClientOptions{
Client: &http.Client{
Transport: httpTransport,
},
})
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
if !httpTransport.hit {
t.Fatalf("Custom client was not used")
}
}
func TestHttpCustomClientPackageScope(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
httpTransport := &customHttpTransport{}
DefaultHttpClient = &http.Client{
Transport: httpTransport,
}
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
if !httpTransport.hit {
t.Fatalf("Custom client was not used")
}
}
type customHttpTransport struct {
hit bool
}
func (c *customHttpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
c.hit = true
return http.DefaultTransport.RoundTrip(req)
}

View file

@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
func NewThriftHandlerFunc(processor TProcessor,
inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) {
return gz(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/x-thrift")
transport := NewStreamTransport(r.Body, w)
processor.Process(r.Context(), inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
})
}
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

View file

@ -0,0 +1,214 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"context"
"io"
)
// StreamTransport is a Transport made of an io.Reader and/or an io.Writer
type StreamTransport struct {
io.Reader
io.Writer
isReadWriter bool
closed bool
}
type StreamTransportFactory struct {
Reader io.Reader
Writer io.Writer
isReadWriter bool
}
func (p *StreamTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*StreamTransport)
if ok {
if t.isReadWriter {
return NewStreamTransportRW(t.Reader.(io.ReadWriter)), nil
}
if t.Reader != nil && t.Writer != nil {
return NewStreamTransport(t.Reader, t.Writer), nil
}
if t.Reader != nil && t.Writer == nil {
return NewStreamTransportR(t.Reader), nil
}
if t.Reader == nil && t.Writer != nil {
return NewStreamTransportW(t.Writer), nil
}
return &StreamTransport{}, nil
}
}
if p.isReadWriter {
return NewStreamTransportRW(p.Reader.(io.ReadWriter)), nil
}
if p.Reader != nil && p.Writer != nil {
return NewStreamTransport(p.Reader, p.Writer), nil
}
if p.Reader != nil && p.Writer == nil {
return NewStreamTransportR(p.Reader), nil
}
if p.Reader == nil && p.Writer != nil {
return NewStreamTransportW(p.Writer), nil
}
return &StreamTransport{}, nil
}
func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory {
return &StreamTransportFactory{Reader: reader, Writer: writer, isReadWriter: isReadWriter}
}
func NewStreamTransport(r io.Reader, w io.Writer) *StreamTransport {
return &StreamTransport{Reader: bufio.NewReader(r), Writer: bufio.NewWriter(w)}
}
func NewStreamTransportR(r io.Reader) *StreamTransport {
return &StreamTransport{Reader: bufio.NewReader(r)}
}
func NewStreamTransportW(w io.Writer) *StreamTransport {
return &StreamTransport{Writer: bufio.NewWriter(w)}
}
func NewStreamTransportRW(rw io.ReadWriter) *StreamTransport {
bufrw := bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw))
return &StreamTransport{Reader: bufrw, Writer: bufrw, isReadWriter: true}
}
func (p *StreamTransport) IsOpen() bool {
return !p.closed
}
// implicitly opened on creation, can't be reopened once closed
func (p *StreamTransport) Open() error {
if !p.closed {
return NewTTransportException(ALREADY_OPEN, "StreamTransport already open.")
} else {
return NewTTransportException(NOT_OPEN, "cannot reopen StreamTransport.")
}
}
// Closes both the input and output streams.
func (p *StreamTransport) Close() error {
if p.closed {
return NewTTransportException(NOT_OPEN, "StreamTransport already closed.")
}
p.closed = true
closedReader := false
if p.Reader != nil {
c, ok := p.Reader.(io.Closer)
if ok {
e := c.Close()
closedReader = true
if e != nil {
return e
}
}
p.Reader = nil
}
if p.Writer != nil && (!closedReader || !p.isReadWriter) {
c, ok := p.Writer.(io.Closer)
if ok {
e := c.Close()
if e != nil {
return e
}
}
p.Writer = nil
}
return nil
}
// Flushes the underlying output stream if not null.
func (p *StreamTransport) Flush(ctx context.Context) error {
if p.Writer == nil {
return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream")
}
f, ok := p.Writer.(Flusher)
if ok {
err := f.Flush()
if err != nil {
return NewTTransportExceptionFromError(err)
}
}
return nil
}
func (p *StreamTransport) Read(c []byte) (n int, err error) {
n, err = p.Reader.Read(c)
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) ReadByte() (c byte, err error) {
f, ok := p.Reader.(io.ByteReader)
if ok {
c, err = f.ReadByte()
} else {
c, err = readByte(p.Reader)
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) Write(c []byte) (n int, err error) {
n, err = p.Writer.Write(c)
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) WriteByte(c byte) (err error) {
f, ok := p.Writer.(io.ByteWriter)
if ok {
err = f.WriteByte(c)
} else {
err = writeByte(p.Writer, c)
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) WriteString(s string) (n int, err error) {
f, ok := p.Writer.(stringWriter)
if ok {
n, err = f.WriteString(s)
} else {
n, err = p.Writer.Write([]byte(s))
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
}

View file

@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"testing"
)
func TestStreamTransport(t *testing.T) {
trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024)))
TransportTest(t, trans, trans)
}
func TestStreamTransportOpenClose(t *testing.T) {
trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024)))
if !trans.IsOpen() {
t.Fatal("StreamTransport should be already open")
}
if trans.Open() == nil {
t.Fatal("StreamTransport should return error when open twice")
}
if trans.Close() != nil {
t.Fatal("StreamTransport should not return error when closing open transport")
}
if trans.IsOpen() {
t.Fatal("StreamTransport should not be open after close")
}
if trans.Close() == nil {
t.Fatal("StreamTransport should return error when closing a non open transport")
}
if trans.Open() == nil {
t.Fatal("StreamTransport should not be able to reopen")
}
}

View file

@ -0,0 +1,584 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/base64"
"fmt"
)
const (
THRIFT_JSON_PROTOCOL_VERSION = 1
)
// for references to _ParseContext see tsimplejson_protocol.go
// JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
// suitable for parsing by scripting languages. It should not be
// confused with the full-featured TJSONProtocol.
//
type TJSONProtocol struct {
*TSimpleJSONProtocol
}
// Constructor
func NewTJSONProtocol(t TTransport) *TJSONProtocol {
v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
return v
}
// Factory
type TJSONProtocolFactory struct{}
func (p *TJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTJSONProtocol(trans)
}
func NewTJSONProtocolFactory() *TJSONProtocolFactory {
return &TJSONProtocolFactory{}
}
func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil {
return e
}
if e := p.WriteString(name); e != nil {
return e
}
if e := p.WriteByte(int8(typeId)); e != nil {
return e
}
if e := p.WriteI32(seqId); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteMessageEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteStructBegin(name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteStructEnd() error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if e := p.WriteI16(id); e != nil {
return e
}
if e := p.OutputObjectBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(typeId)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteFieldEnd() error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldStop() error { return nil }
func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(keyType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
s, e1 = p.TypeIdToString(valueType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
return e
}
return p.OutputObjectBegin()
}
func (p *TJSONProtocol) WriteMapEnd() error {
if e := p.OutputObjectEnd(); e != nil {
return e
}
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteListEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteSetEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteBool(b bool) error {
if b {
return p.WriteI32(1)
}
return p.WriteI32(0)
}
func (p *TJSONProtocol) WriteByte(b int8) error {
return p.WriteI32(int32(b))
}
func (p *TJSONProtocol) WriteI16(v int16) error {
return p.WriteI32(int32(v))
}
func (p *TJSONProtocol) WriteI32(v int32) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteI64(v int64) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteDouble(v float64) error {
return p.OutputF64(v)
}
func (p *TJSONProtocol) WriteString(v string) error {
return p.OutputString(v)
}
func (p *TJSONProtocol) WriteBinary(v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
// therefore we use base64 encoding to avoid excessive escaping/quoting
if e := p.OutputPreValue(); e != nil {
return e
}
if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
return NewTProtocolException(e)
}
writer := base64.NewEncoder(base64.StdEncoding, p.writer)
if _, e := writer.Write(v); e != nil {
p.writer.Reset(p.trans) // THRIFT-3735
return NewTProtocolException(e)
}
if e := writer.Close(); e != nil {
return NewTProtocolException(e)
}
if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
return NewTProtocolException(e)
}
return p.OutputPostValue()
}
// Reading methods.
func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
version, err := p.ReadI32()
if err != nil {
return name, typeId, seqId, err
}
if version != THRIFT_JSON_PROTOCOL_VERSION {
e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION)
return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
if name, err = p.ReadString(); err != nil {
return name, typeId, seqId, err
}
bTypeId, err := p.ReadByte()
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
if seqId, err = p.ReadI32(); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
func (p *TJSONProtocol) ReadMessageEnd() error {
err := p.ParseListEnd()
return err
}
func (p *TJSONProtocol) ReadStructBegin() (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
func (p *TJSONProtocol) ReadStructEnd() error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
b, _ := p.reader.Peek(1)
if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
return "", STOP, -1, nil
}
fieldId, err := p.ReadI16()
if err != nil {
return "", STOP, fieldId, err
}
if _, err = p.ParseObjectStart(); err != nil {
return "", STOP, fieldId, err
}
sType, err := p.ReadString()
if err != nil {
return "", STOP, fieldId, err
}
fType, err := p.StringToTypeId(sType)
return "", fType, fieldId, err
}
func (p *TJSONProtocol) ReadFieldEnd() error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
sKeyType, e := p.ReadString()
if e != nil {
return keyType, valueType, size, e
}
keyType, e = p.StringToTypeId(sKeyType)
if e != nil {
return keyType, valueType, size, e
}
// read valueType
sValueType, e := p.ReadString()
if e != nil {
return keyType, valueType, size, e
}
valueType, e = p.StringToTypeId(sValueType)
if e != nil {
return keyType, valueType, size, e
}
// read size
iSize, e := p.ReadI64()
if e != nil {
return keyType, valueType, size, e
}
size = int(iSize)
_, e = p.ParseObjectStart()
return keyType, valueType, size, e
}
func (p *TJSONProtocol) ReadMapEnd() error {
e := p.ParseObjectEnd()
if e != nil {
return e
}
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadListEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadSetEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadBool() (bool, error) {
value, err := p.ReadI32()
return (value != 0), err
}
func (p *TJSONProtocol) ReadByte() (int8, error) {
v, err := p.ReadI64()
return int8(v), err
}
func (p *TJSONProtocol) ReadI16() (int16, error) {
v, err := p.ReadI64()
return int16(v), err
}
func (p *TJSONProtocol) ReadI32() (int32, error) {
v, err := p.ReadI64()
return int32(v), err
}
func (p *TJSONProtocol) ReadI64() (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
func (p *TJSONProtocol) ReadDouble() (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
func (p *TJSONProtocol) ReadString() (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
}
f, _ := p.reader.Peek(1)
if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseStringBody()
v = value
if err != nil {
return v, err
}
} else if len(f) > 0 && f[0] == JSON_NULL[0] {
b := make([]byte, len(JSON_NULL))
_, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
if string(b) != string(JSON_NULL) {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
} else {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
}
f, _ := p.reader.Peek(1)
if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseBase64EncodedBody()
v = value
if err != nil {
return v, err
}
} else if len(f) > 0 && f[0] == JSON_NULL[0] {
b := make([]byte, len(JSON_NULL))
_, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
if string(b) != string(JSON_NULL) {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
} else {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
err = p.writer.Flush()
if err == nil {
err = p.trans.Flush(ctx)
}
return NewTProtocolException(err)
}
func (p *TJSONProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TJSONProtocol) Transport() TTransport {
return p.trans
}
func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(elemType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
if err != nil {
return VOID, size, err
}
elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
size = int(nSize)
return elemType, size, err2
}
func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
if err != nil {
return VOID, size, err
}
elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
size = int(nSize)
return elemType, size, err2
}
func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(elemType)
if e1 != nil {
return e1
}
if e := p.OutputString(s); e != nil {
return e
}
if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) {
switch byte(fieldType) {
case BOOL:
return "tf", nil
case BYTE:
return "i8", nil
case I16:
return "i16", nil
case I32:
return "i32", nil
case I64:
return "i64", nil
case DOUBLE:
return "dbl", nil
case STRING:
return "str", nil
case STRUCT:
return "rec", nil
case MAP:
return "map", nil
case SET:
return "set", nil
case LIST:
return "lst", nil
}
e := fmt.Errorf("Unknown fieldType: %d", int(fieldType))
return "", NewTProtocolExceptionWithType(INVALID_DATA, e)
}
func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
switch fieldType {
case "tf":
return TType(BOOL), nil
case "i8":
return TType(BYTE), nil
case "i16":
return TType(I16), nil
case "i32":
return TType(I32), nil
case "i64":
return TType(I64), nil
case "dbl":
return TType(DOUBLE), nil
case "str":
return TType(STRING), nil
case "rec":
return TType(STRUCT), nil
case "map":
return TType(MAP), nil
case "set":
return TType(SET), nil
case "lst":
return TType(LIST), nil
}
e := fmt.Errorf("Unknown type identifier: %s", fieldType)
return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}

View file

@ -0,0 +1,650 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"math"
"strconv"
"testing"
)
func TestWriteJSONProtocolBool(t *testing.T) {
thetype := "boolean"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range BOOL_VALUES {
if e := p.WriteBool(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
expected := ""
if value {
expected = "1"
} else {
expected = "0"
}
if s != expected {
t.Fatalf("Bad value for %s %v: %s expected", thetype, value, s)
}
v := -1
if err := json.Unmarshal([]byte(s), &v); err != nil || (v != 0) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolBool(t *testing.T) {
thetype := "boolean"
for _, value := range BOOL_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
if value {
trans.Write([]byte{'1'}) // not JSON_TRUE
} else {
trans.Write([]byte{'0'}) // not JSON_FALSE
}
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBool()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
vv := -1
if err := json.Unmarshal([]byte(s), &vv); err != nil || (vv != 0) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, vv)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolByte(t *testing.T) {
thetype := "byte"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range BYTE_VALUES {
if e := p.WriteByte(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int8(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolByte(t *testing.T) {
thetype := "byte"
for _, value := range BYTE_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadByte()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolI16(t *testing.T) {
thetype := "int16"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT16_VALUES {
if e := p.WriteI16(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int16(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolI16(t *testing.T) {
thetype := "int16"
for _, value := range INT16_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI16()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolI32(t *testing.T) {
thetype := "int32"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT32_VALUES {
if e := p.WriteI32(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int32(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolI32(t *testing.T) {
thetype := "int32"
for _, value := range INT32_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI32()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolI64(t *testing.T) {
thetype := "int64"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT64_VALUES {
if e := p.WriteI64(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int64(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolI64(t *testing.T) {
thetype := "int64"
for _, value := range INT64_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.FormatInt(value, 10))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI64()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolDouble(t *testing.T) {
thetype := "double"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range DOUBLE_VALUES {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if math.IsInf(value, 1) {
if s != jsonQuote(JSON_INFINITY) {
t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY))
}
} else if math.IsInf(value, -1) {
if s != jsonQuote(JSON_NEGATIVE_INFINITY) {
t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY))
}
} else if math.IsNaN(value) {
if s != jsonQuote(JSON_NAN) {
t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN))
}
} else {
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := float64(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolDouble(t *testing.T) {
thetype := "double"
for _, value := range DOUBLE_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
n := NewNumericFromDouble(value)
trans.WriteString(n.String())
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadDouble()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if math.IsInf(value, 1) {
if !math.IsInf(v, 1) {
t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v)
}
} else if math.IsInf(value, -1) {
if !math.IsInf(v, -1) {
t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v)
}
} else if math.IsNaN(value) {
if !math.IsNaN(v) {
t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v)
}
} else {
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolString(t *testing.T) {
thetype := "string"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range STRING_VALUES {
if e := p.WriteString(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s[0] != '"' || s[len(s)-1] != '"' {
t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\""))
}
v := new(string)
if err := json.Unmarshal([]byte(s), v); err != nil || *v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolString(t *testing.T) {
thetype := "string"
for _, value := range STRING_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(jsonQuote(value))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadString()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
v1 := new(string)
if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolBinary(t *testing.T) {
thetype := "binary"
value := protocol_bdata
b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata)))
base64.StdEncoding.Encode(b64value, value)
b64String := string(b64value)
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
if e := p.WriteBinary(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
expectedString := fmt.Sprint("\"", b64String, "\"")
if s != expectedString {
t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString)
}
v1, err := p.ReadBinary()
if err != nil {
t.Fatalf("Unable to read binary: %s", err.Error())
}
if len(v1) != len(value) {
t.Fatalf("Invalid value for binary\nexpected: \"%v\"\n read: \"%v\"", value, v1)
}
for k, v := range value {
if v1[k] != v {
t.Fatalf("Invalid value for binary at %v\nexpected: \"%v\"\n read: \"%v\"", k, v, v1[k])
}
}
trans.Close()
}
func TestReadJSONProtocolBinary(t *testing.T) {
thetype := "binary"
value := protocol_bdata
b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata)))
base64.StdEncoding.Encode(b64value, value)
b64String := string(b64value)
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(jsonQuote(b64String))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBinary()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if len(v) != len(value) {
t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v))
}
for i := 0; i < len(v); i++ {
if v[i] != value[i] {
t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i])
}
}
v1 := new(string)
if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1)
}
trans.Reset()
trans.Close()
}
func TestWriteJSONProtocolList(t *testing.T) {
thetype := "list"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
p.WriteListEnd()
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
str1 := new([]interface{})
err := json.Unmarshal([]byte(str), str1)
if err != nil {
t.Fatalf("Unable to decode %s, wrote: %s", thetype, str)
}
l := *str1
if len(l) < 2 {
t.Fatalf("List must be at least of length two to include metadata")
}
if l[0] != "dbl" {
t.Fatal("Invalid type for list, expected: ", STRING, ", but was: ", l[0])
}
if int(l[1].(float64)) != len(DOUBLE_VALUES) {
t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1])
}
for k, value := range DOUBLE_VALUES {
s := l[k+2]
if math.IsInf(value, 1) {
if s.(string) != JSON_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str)
}
} else if math.IsInf(value, 0) {
if s.(string) != JSON_NEGATIVE_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str)
}
} else if math.IsNaN(value) {
if s.(string) != JSON_NAN {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str)
}
} else {
if s.(float64) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s)
}
}
trans.Reset()
}
trans.Close()
}
func TestWriteJSONProtocolSet(t *testing.T) {
thetype := "set"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
p.WriteSetEnd()
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
str1 := new([]interface{})
err := json.Unmarshal([]byte(str), str1)
if err != nil {
t.Fatalf("Unable to decode %s, wrote: %s", thetype, str)
}
l := *str1
if len(l) < 2 {
t.Fatalf("Set must be at least of length two to include metadata")
}
if l[0] != "dbl" {
t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0])
}
if int(l[1].(float64)) != len(DOUBLE_VALUES) {
t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1])
}
for k, value := range DOUBLE_VALUES {
s := l[k+2]
if math.IsInf(value, 1) {
if s.(string) != JSON_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str)
}
} else if math.IsInf(value, 0) {
if s.(string) != JSON_NEGATIVE_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str)
}
} else if math.IsNaN(value) {
if s.(string) != JSON_NAN {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str)
}
} else {
if s.(float64) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s)
}
}
trans.Reset()
}
trans.Close()
}
func TestWriteJSONProtocolMap(t *testing.T) {
thetype := "map"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES))
for k, value := range DOUBLE_VALUES {
if e := p.WriteI32(int32(k)); e != nil {
t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error())
}
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error())
}
}
p.WriteMapEnd()
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
if str[0] != '[' || str[len(str)-1] != ']' {
t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES)
}
expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin()
if err != nil {
t.Fatalf("Error while reading map begin: %s", err.Error())
}
if expectedKeyType != I32 {
t.Fatal("Expected map key type ", I32, ", but was ", expectedKeyType)
}
if expectedValueType != DOUBLE {
t.Fatal("Expected map value type ", DOUBLE, ", but was ", expectedValueType)
}
if expectedSize != len(DOUBLE_VALUES) {
t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize)
}
for k, value := range DOUBLE_VALUES {
ik, err := p.ReadI32()
if err != nil {
t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error())
}
if int(ik) != k {
t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k)
}
dv, err := p.ReadDouble()
if err != nil {
t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error())
}
s := strconv.FormatFloat(dv, 'g', 10, 64)
if math.IsInf(value, 1) {
if !math.IsInf(dv, 1) {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY))
}
} else if math.IsInf(value, 0) {
if !math.IsInf(dv, 0) {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY))
}
} else if math.IsNaN(value) {
if !math.IsNaN(dv) {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN))
}
} else {
expected := strconv.FormatFloat(value, 'g', 10, 64)
if s != expected {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected)
}
v := float64(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
}
}
err = p.ReadMapEnd()
if err != nil {
t.Fatalf("Error while reading map end: %s", err.Error())
}
trans.Close()
}

View file

@ -0,0 +1,540 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"testing"
)
var binaryProtoF = NewTBinaryProtocolFactoryDefault()
var compactProtoF = NewTCompactProtocolFactory()
var buf = bytes.NewBuffer(make([]byte, 0, 1024))
var tfv = []TTransportFactory{
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
}
func BenchmarkBinaryBool_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkBinaryByte_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkBinaryI16_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkBinaryI32_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkBinaryI64_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkBinaryDouble_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkBinaryString_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkBinaryBinary_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkBinaryBool_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkBinaryByte_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkBinaryI16_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkBinaryI32_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkBinaryI64_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkBinaryDouble_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkBinaryString_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkBinaryBinary_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkBinaryBool_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkBinaryByte_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkBinaryI16_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkBinaryI32_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkBinaryI64_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkBinaryDouble_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkBinaryString_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkBinaryBinary_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkCompactBool_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkCompactByte_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkCompactI16_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkCompactI32_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkCompactI64_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkCompactDouble0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkCompactString0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkCompactBinary0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkCompactBool_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkCompactByte_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkCompactI16_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkCompactI32_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkCompactI64_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkCompactDouble1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkCompactString1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkCompactBinary1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkCompactBool_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkCompactByte_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkCompactI16_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkCompactI32_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkCompactI64_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkCompactDouble2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkCompactString2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkCompactBinary2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}

View file

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
)
// Memory buffer-based implementation of the TTransport interface.
type TMemoryBuffer struct {
*bytes.Buffer
size int
}
type TMemoryBufferTransportFactory struct {
size int
}
func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*TMemoryBuffer)
if ok && t.size > 0 {
return NewTMemoryBufferLen(t.size), nil
}
}
return NewTMemoryBufferLen(p.size), nil
}
func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory {
return &TMemoryBufferTransportFactory{size: size}
}
func NewTMemoryBuffer() *TMemoryBuffer {
return &TMemoryBuffer{Buffer: &bytes.Buffer{}, size: 0}
}
func NewTMemoryBufferLen(size int) *TMemoryBuffer {
buf := make([]byte, 0, size)
return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size}
}
func (p *TMemoryBuffer) IsOpen() bool {
return true
}
func (p *TMemoryBuffer) Open() error {
return nil
}
func (p *TMemoryBuffer) Close() error {
p.Buffer.Reset()
return nil
}
// Flushing a memory buffer is a no-op
func (p *TMemoryBuffer) Flush(ctx context.Context) error {
return nil
}
func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) {
return uint64(p.Buffer.Len())
}

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestMemoryBuffer(t *testing.T) {
trans := NewTMemoryBufferLen(1024)
TransportTest(t, trans, trans)
}

View file

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Message type constants in the Thrift protocol.
type TMessageType int32
const (
INVALID_TMESSAGE_TYPE TMessageType = 0
CALL TMessageType = 1
REPLY TMessageType = 2
EXCEPTION TMessageType = 3
ONEWAY TMessageType = 4
)

View file

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"fmt"
"strings"
)
/*
TMultiplexedProtocol is a protocol-independent concrete decorator
that allows a Thrift client to communicate with a multiplexing Thrift server,
by prepending the service name to the function name during function calls.
NOTE: THIS IS NOT USED BY SERVERS. On the server, use TMultiplexedProcessor to handle request
from a multiplexing client.
This example uses a single socket transport to invoke two services:
socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
transport := thrift.NewTFramedTransport(socket)
protocol := thrift.NewTBinaryProtocolTransport(transport)
mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator")
service := Calculator.NewCalculatorClient(mp)
mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport")
service2 := WeatherReport.NewWeatherReportClient(mp2)
err := transport.Open()
if err != nil {
t.Fatal("Unable to open client socket", err)
}
fmt.Println(service.Add(2,2))
fmt.Println(service2.GetTemperature())
*/
type TMultiplexedProtocol struct {
TProtocol
serviceName string
}
const MULTIPLEXED_SEPARATOR = ":"
func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol {
return &TMultiplexedProtocol{
TProtocol: protocol,
serviceName: serviceName,
}
}
func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
if typeId == CALL || typeId == ONEWAY {
return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
} else {
return t.TProtocol.WriteMessageBegin(name, typeId, seqid)
}
}
/*
TMultiplexedProcessor is a TProcessor allowing
a single TServer to provide multiple services.
To do so, you instantiate the processor and then register additional
processors with it, as shown in the following example:
var processor = thrift.NewTMultiplexedProcessor()
firstProcessor :=
processor.RegisterProcessor("FirstService", firstProcessor)
processor.registerProcessor(
"Calculator",
Calculator.NewCalculatorProcessor(&CalculatorHandler{}),
)
processor.registerProcessor(
"WeatherReport",
WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}),
)
serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT)
if err != nil {
t.Fatal("Unable to create server socket", err)
}
server := thrift.NewTSimpleServer2(processor, serverTransport)
server.Serve();
*/
type TMultiplexedProcessor struct {
serviceProcessorMap map[string]TProcessor
DefaultProcessor TProcessor
}
func NewTMultiplexedProcessor() *TMultiplexedProcessor {
return &TMultiplexedProcessor{
serviceProcessorMap: make(map[string]TProcessor),
}
}
func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) {
t.DefaultProcessor = processor
}
func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) {
if t.serviceProcessorMap == nil {
t.serviceProcessorMap = make(map[string]TProcessor)
}
t.serviceProcessorMap[name] = processor
}
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin()
if err != nil {
return false, err
}
if typeId != CALL && typeId != ONEWAY {
return false, fmt.Errorf("Unexpected message type %v", typeId)
}
//extract the service name
v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
if len(v) != 2 {
if t.DefaultProcessor != nil {
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(ctx, smb, out)
}
return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
}
actualProcessor, ok := t.serviceProcessorMap[v[0]]
if !ok {
return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(ctx, smb, out)
}
//Protocol that use stored message for ReadMessageBegin
type storedMessageProtocol struct {
TProtocol
name string
typeId TMessageType
seqid int32
}
func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol {
return &storedMessageProtocol{protocol, name, typeId, seqid}
}
func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
return s.name, s.typeId, s.seqid, nil
}

View file

@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"math"
"strconv"
)
type Numeric interface {
Int64() int64
Int32() int32
Int16() int16
Byte() byte
Int() int
Float64() float64
Float32() float32
String() string
isNull() bool
}
type numeric struct {
iValue int64
dValue float64
sValue string
isNil bool
}
var (
INFINITY Numeric
NEGATIVE_INFINITY Numeric
NAN Numeric
ZERO Numeric
NUMERIC_NULL Numeric
)
func NewNumericFromDouble(dValue float64) Numeric {
if math.IsInf(dValue, 1) {
return INFINITY
}
if math.IsInf(dValue, -1) {
return NEGATIVE_INFINITY
}
if math.IsNaN(dValue) {
return NAN
}
iValue := int64(dValue)
sValue := strconv.FormatFloat(dValue, 'g', 10, 64)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI64(iValue int64) Numeric {
dValue := float64(iValue)
sValue := string(iValue)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI32(iValue int32) Numeric {
dValue := float64(iValue)
sValue := string(iValue)
isNil := false
return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromString(sValue string) Numeric {
if sValue == INFINITY.String() {
return INFINITY
}
if sValue == NEGATIVE_INFINITY.String() {
return NEGATIVE_INFINITY
}
if sValue == NAN.String() {
return NAN
}
iValue, _ := strconv.ParseInt(sValue, 10, 64)
dValue, _ := strconv.ParseFloat(sValue, 64)
isNil := len(sValue) == 0
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromJSONString(sValue string, isNull bool) Numeric {
if isNull {
return NewNullNumeric()
}
if sValue == JSON_INFINITY {
return INFINITY
}
if sValue == JSON_NEGATIVE_INFINITY {
return NEGATIVE_INFINITY
}
if sValue == JSON_NAN {
return NAN
}
iValue, _ := strconv.ParseInt(sValue, 10, 64)
dValue, _ := strconv.ParseFloat(sValue, 64)
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNull}
}
func NewNullNumeric() Numeric {
return &numeric{iValue: 0, dValue: 0.0, sValue: "", isNil: true}
}
func (p *numeric) Int64() int64 {
return p.iValue
}
func (p *numeric) Int32() int32 {
return int32(p.iValue)
}
func (p *numeric) Int16() int16 {
return int16(p.iValue)
}
func (p *numeric) Byte() byte {
return byte(p.iValue)
}
func (p *numeric) Int() int {
return int(p.iValue)
}
func (p *numeric) Float64() float64 {
return p.dValue
}
func (p *numeric) Float32() float32 {
return float32(p.dValue)
}
func (p *numeric) String() string {
return p.sValue
}
func (p *numeric) isNull() bool {
return p.isNil
}
func init() {
INFINITY = &numeric{iValue: 0, dValue: math.Inf(1), sValue: "Infinity", isNil: false}
NEGATIVE_INFINITY = &numeric{iValue: 0, dValue: math.Inf(-1), sValue: "-Infinity", isNil: false}
NAN = &numeric{iValue: 0, dValue: math.NaN(), sValue: "NaN", isNil: false}
ZERO = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: false}
NUMERIC_NULL = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: true}
}

View file

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
///////////////////////////////////////////////////////////////////////////////
// This file is home to helpers that convert from various base types to
// respective pointer types. This is necessary because Go does not permit
// references to constants, nor can a pointer type to base type be allocated
// and initialized in a single expression.
//
// E.g., this is not allowed:
//
// var ip *int = &5
//
// But this *is* allowed:
//
// func IntPtr(i int) *int { return &i }
// var ip *int = IntPtr(5)
//
// Since pointers to base types are commonplace as [optional] fields in
// exported thrift structs, we factor such helpers here.
///////////////////////////////////////////////////////////////////////////////
func Float32Ptr(v float32) *float32 { return &v }
func Float64Ptr(v float64) *float64 { return &v }
func IntPtr(v int) *int { return &v }
func Int32Ptr(v int32) *int32 { return &v }
func Int64Ptr(v int64) *int64 { return &v }
func StringPtr(v string) *string { return &v }
func Uint32Ptr(v uint32) *uint32 { return &v }
func Uint64Ptr(v uint64) *uint64 { return &v }
func BoolPtr(v bool) *bool { return &v }
func ByteSlicePtr(v []byte) *[]byte { return &v }

View file

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(ctx context.Context, in, out TProtocol) (bool, TException)
}
type TProcessorFunction interface {
Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}
// The default processor factory just returns a singleton
// instance.
type TProcessorFactory interface {
GetProcessor(trans TTransport) TProcessor
}
type tProcessorFactory struct {
processor TProcessor
}
func NewTProcessorFactory(p TProcessor) TProcessorFactory {
return &tProcessorFactory{processor: p}
}
func (p *tProcessorFactory) GetProcessor(trans TTransport) TProcessor {
return p.processor
}
/**
* The default processor factory just returns a singleton
* instance.
*/
type TProcessorFunctionFactory interface {
GetProcessorFunction(trans TTransport) TProcessorFunction
}
type tProcessorFunctionFactory struct {
processor TProcessorFunction
}
func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactory {
return &tProcessorFunctionFactory{processor: p}
}
func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction {
return p.processor
}

View file

@ -0,0 +1,179 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"errors"
"fmt"
)
const (
VERSION_MASK = 0xffff0000
VERSION_1 = 0x80010000
)
type TProtocol interface {
WriteMessageBegin(name string, typeId TMessageType, seqid int32) error
WriteMessageEnd() error
WriteStructBegin(name string) error
WriteStructEnd() error
WriteFieldBegin(name string, typeId TType, id int16) error
WriteFieldEnd() error
WriteFieldStop() error
WriteMapBegin(keyType TType, valueType TType, size int) error
WriteMapEnd() error
WriteListBegin(elemType TType, size int) error
WriteListEnd() error
WriteSetBegin(elemType TType, size int) error
WriteSetEnd() error
WriteBool(value bool) error
WriteByte(value int8) error
WriteI16(value int16) error
WriteI32(value int32) error
WriteI64(value int64) error
WriteDouble(value float64) error
WriteString(value string) error
WriteBinary(value []byte) error
ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error)
ReadMessageEnd() error
ReadStructBegin() (name string, err error)
ReadStructEnd() error
ReadFieldBegin() (name string, typeId TType, id int16, err error)
ReadFieldEnd() error
ReadMapBegin() (keyType TType, valueType TType, size int, err error)
ReadMapEnd() error
ReadListBegin() (elemType TType, size int, err error)
ReadListEnd() error
ReadSetBegin() (elemType TType, size int, err error)
ReadSetEnd() error
ReadBool() (value bool, err error)
ReadByte() (value int8, err error)
ReadI16() (value int16, err error)
ReadI32() (value int32, err error)
ReadI64() (value int64, err error)
ReadDouble() (value float64, err error)
ReadString() (value string, err error)
ReadBinary() (value []byte, err error)
Skip(fieldType TType) (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
}
// The maximum recursive depth the skip() function will traverse
const DEFAULT_RECURSION_DEPTH = 64
// Skips over the next data element from the provided input TProtocol object.
func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) {
return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH)
}
// Skips over the next data element from the provided input TProtocol object.
func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
if maxDepth <= 0 {
return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded"))
}
switch fieldType {
case STOP:
return
case BOOL:
_, err = self.ReadBool()
return
case BYTE:
_, err = self.ReadByte()
return
case I16:
_, err = self.ReadI16()
return
case I32:
_, err = self.ReadI32()
return
case I64:
_, err = self.ReadI64()
return
case DOUBLE:
_, err = self.ReadDouble()
return
case STRING:
_, err = self.ReadString()
return
case STRUCT:
if _, err = self.ReadStructBegin(); err != nil {
return err
}
for {
_, typeId, _, _ := self.ReadFieldBegin()
if typeId == STOP {
break
}
err := Skip(self, typeId, maxDepth-1)
if err != nil {
return err
}
self.ReadFieldEnd()
}
return self.ReadStructEnd()
case MAP:
keyType, valueType, size, err := self.ReadMapBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, keyType, maxDepth-1)
if err != nil {
return err
}
self.Skip(valueType)
}
return self.ReadMapEnd()
case SET:
elemType, size, err := self.ReadSetBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadSetEnd()
case LIST:
elemType, size, err := self.ReadListBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadListEnd()
default:
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType)))
}
return nil
}

View file

@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"encoding/base64"
)
// Thrift Protocol exception
type TProtocolException interface {
TException
TypeId() int
}
const (
UNKNOWN_PROTOCOL_EXCEPTION = 0
INVALID_DATA = 1
NEGATIVE_SIZE = 2
SIZE_LIMIT = 3
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
)
type tProtocolException struct {
typeId int
message string
}
func (p *tProtocolException) TypeId() int {
return p.typeId
}
func (p *tProtocolException) String() string {
return p.message
}
func (p *tProtocolException) Error() string {
return p.message
}
func NewTProtocolException(err error) TProtocolException {
if err == nil {
return nil
}
if e, ok := err.(TProtocolException); ok {
return e
}
if _, ok := err.(base64.CorruptInputError); ok {
return &tProtocolException{INVALID_DATA, err.Error()}
}
return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()}
}
func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
if err == nil {
return nil
}
return &tProtocolException{errType, err.Error()}
}

View file

@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Factory interface for constructing protocol instances.
type TProtocolFactory interface {
GetProtocol(trans TTransport) TProtocol
}

View file

@ -0,0 +1,517 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"io/ioutil"
"math"
"net"
"net/http"
"testing"
)
const PROTOCOL_BINARY_DATA_SIZE = 155
var (
protocol_bdata []byte // test data for writing; same as data
BOOL_VALUES []bool
BYTE_VALUES []int8
INT16_VALUES []int16
INT32_VALUES []int32
INT64_VALUES []int64
DOUBLE_VALUES []float64
STRING_VALUES []string
)
func init() {
protocol_bdata = make([]byte, PROTOCOL_BINARY_DATA_SIZE)
for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ {
protocol_bdata[i] = byte((i + 'a') % 255)
}
BOOL_VALUES = []bool{false, true, false, false, true}
BYTE_VALUES = []int8{117, 0, 1, 32, 127, -128, -1}
INT16_VALUES = []int16{459, 0, 1, -1, -128, 127, 32767, -32768}
INT32_VALUES = []int32{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535}
INT64_VALUES = []int64{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535, 34359738481, -35184372088719, -9223372036854775808, 9223372036854775807}
DOUBLE_VALUES = []float64{459.3, 0.0, -1.0, 1.0, 0.5, 0.3333, 3.14159, 1.537e-38, 1.673e25, 6.02214179e23, -6.02214179e23, INFINITY.Float64(), NEGATIVE_INFINITY.Float64(), NAN.Float64()}
STRING_VALUES = []string{"", "a", "st[uf]f", "st,u:ff with spaces", "stuff\twith\nescape\\characters'...\"lots{of}fun</xml>"}
}
type HTTPEchoServer struct{}
type HTTPHeaderEchoServer struct{}
func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
} else {
w.WriteHeader(http.StatusOK)
w.Write(buf)
}
}
func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
} else {
w.WriteHeader(http.StatusOK)
w.Write(buf)
}
}
func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) {
addr, err := FindAvailableTCPServerPort(40000)
if err != nil {
t.Fatalf("Unable to find available tcp port addr: %s", err)
return nil, addr
}
l, err := net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
return l, addr
}
go http.Serve(l, &HTTPEchoServer{})
return l, addr
}
func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) {
addr, err := FindAvailableTCPServerPort(40000)
if err != nil {
t.Fatalf("Unable to find available tcp port addr: %s", err)
return nil, addr
}
l, err := net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
return l, addr
}
go http.Serve(l, &HTTPHeaderEchoServer{})
return l, addr
}
func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
l, addr := HttpClientSetupForTest(t)
defer l.Close()
transports := []TTransportFactory{
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(0, NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(6, NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(9, NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024))),
NewTHttpPostClientTransportFactory("http://" + addr.String()),
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteBool(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteByte(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI16(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI32(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI64(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteDouble(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteString(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteBinary(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI64(t, p, trans)
ReadWriteDouble(t, p, trans)
ReadWriteBinary(t, p, trans)
ReadWriteByte(t, p, trans)
trans.Close()
}
}
func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(BOOL)
thelen := len(BOOL_VALUES)
err := p.WriteListBegin(thetype, thelen)
if err != nil {
t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype)
}
for k, v := range BOOL_VALUES {
err = p.WriteBool(v)
if err != nil {
t.Errorf("%s: %T %T %v Error writing bool in list at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
}
p.WriteListEnd()
if err != nil {
t.Errorf("%s: %T %T %v Error writing list end: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteBool", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteBool", p, trans, thelen, thelen2)
}
}
for k, v := range BOOL_VALUES {
value, err := p.ReadBool()
if err != nil {
t.Errorf("%s: %T %T %v Error reading bool at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: index %v %v %v %v != %v", "ReadWriteBool", k, p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err)
}
}
func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(BYTE)
thelen := len(BYTE_VALUES)
err := p.WriteListBegin(thetype, thelen)
if err != nil {
t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype)
}
for k, v := range BYTE_VALUES {
err = p.WriteByte(v)
if err != nil {
t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v)
}
}
err = p.WriteListEnd()
if err != nil {
t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
err = p.Flush(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteByte", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteByte", p, trans, thelen, thelen2)
}
}
for k, v := range BYTE_VALUES {
value, err := p.ReadByte()
if err != nil {
t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err)
}
}
func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I16)
thelen := len(INT16_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range INT16_VALUES {
p.WriteI16(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI16", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI16", p, trans, thelen, thelen2)
}
}
for k, v := range INT16_VALUES {
value, err := p.ReadI16()
if err != nil {
t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err)
}
}
func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I32)
thelen := len(INT32_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range INT32_VALUES {
p.WriteI32(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI32", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI32", p, trans, thelen, thelen2)
}
}
for k, v := range INT32_VALUES {
value, err := p.ReadI32()
if err != nil {
t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteI32", p, trans, v, value)
}
}
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI32", p, trans, err)
}
}
func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I64)
thelen := len(INT64_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range INT64_VALUES {
p.WriteI64(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI64", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI64", p, trans, thelen, thelen2)
}
}
for k, v := range INT64_VALUES {
value, err := p.ReadI64()
if err != nil {
t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %q != %q", "ReadWriteI64", p, trans, v, value)
}
}
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI64", p, trans, err)
}
}
func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(DOUBLE)
thelen := len(DOUBLE_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range DOUBLE_VALUES {
p.WriteDouble(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES)
}
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteDouble", p, trans, thelen, thelen2)
}
for k, v := range DOUBLE_VALUES {
value, err := p.ReadDouble()
if err != nil {
t.Errorf("%s: %T %T %q Error reading double at index %d: %v", "ReadWriteDouble", p, trans, err, k, v)
}
if math.IsNaN(v) {
if !math.IsNaN(value) {
t.Errorf("%s: %T %T math.IsNaN(%v) != math.IsNaN(%v)", "ReadWriteDouble", p, trans, v, value)
}
} else if v != value {
t.Errorf("%s: %T %T %v != %v", "ReadWriteDouble", p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err)
}
}
func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(STRING)
thelen := len(STRING_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range STRING_VALUES {
p.WriteString(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteString", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteString", p, trans, thelen, thelen2)
}
}
for k, v := range STRING_VALUES {
value, err := p.ReadString()
if err != nil {
t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %v != %v", "ReadWriteString", p, trans, v, value)
}
}
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteString", p, trans, err)
}
}
func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) {
v := protocol_bdata
p.WriteBinary(v)
p.Flush(context.Background())
value, err := p.ReadBinary()
if err != nil {
t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error())
}
if len(v) != len(value) {
t.Errorf("%s: %T %T len(v) != len(value)... %d != %d", "ReadWriteBinary", p, trans, len(v), len(value))
} else {
for i := 0; i < len(v); i++ {
if v[i] != value[i] {
t.Errorf("%s: %T %T %s != %s", "ReadWriteBinary", p, trans, v, value)
}
}
}
}

View file

@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "io"
type RichTransport struct {
TTransport
}
// Wraps Transport to provide TRichTransport interface
func NewTRichTransport(trans TTransport) *RichTransport {
return &RichTransport{trans}
}
func (r *RichTransport) ReadByte() (c byte, err error) {
return readByte(r.TTransport)
}
func (r *RichTransport) WriteByte(c byte) error {
return writeByte(r.TTransport, c)
}
func (r *RichTransport) WriteString(s string) (n int, err error) {
return r.Write([]byte(s))
}
func (r *RichTransport) RemainingBytes() (num_bytes uint64) {
return r.TTransport.RemainingBytes()
}
func readByte(r io.Reader) (c byte, err error) {
v := [1]byte{0}
n, err := r.Read(v[0:1])
if n > 0 && (err == nil || err == io.EOF) {
return v[0], nil
}
if n > 0 && err != nil {
return v[0], err
}
if err != nil {
return 0, err
}
return v[0], nil
}
func writeByte(w io.Writer, c byte) error {
v := [1]byte{c}
_, err := w.Write(v[0:1])
return err
}

View file

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"errors"
"io"
"reflect"
"testing"
)
func TestEnsureTransportsAreRich(t *testing.T) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
transports := []TTransportFactory{
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
NewTHttpPostClientTransportFactory("http://127.0.0.1"),
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
_, ok := trans.(TRichTransport)
if !ok {
t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans))
}
}
}
// TestReadByte tests whether readByte handles error cases correctly.
func TestReadByte(t *testing.T) {
for i, test := range readByteTests {
v, err := readByte(test.r)
if v != test.v {
t.Fatalf("TestReadByte %d: value differs. Expected %d, got %d", i, test.v, test.r.v)
}
if err != test.err {
t.Fatalf("TestReadByte %d: error differs. Expected %s, got %s", i, test.err, test.r.err)
}
}
}
var someError = errors.New("Some error")
var readByteTests = []struct {
r *mockReader
v byte
err error
}{
{&mockReader{0, 55, io.EOF}, 0, io.EOF}, // reader sends EOF w/o data
{&mockReader{0, 55, someError}, 0, someError}, // reader sends some other error
{&mockReader{1, 55, nil}, 55, nil}, // reader sends data w/o error
{&mockReader{1, 55, io.EOF}, 55, nil}, // reader sends data with EOF
{&mockReader{1, 55, someError}, 55, someError}, // reader sends data withsome error
}
type mockReader struct {
n int
v byte
err error
}
func (r *mockReader) Read(p []byte) (n int, err error) {
if r.n > 0 {
p[0] = r.v
}
return r.n, r.err
}

View file

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
)
type TSerializer struct {
Transport *TMemoryBuffer
Protocol TProtocol
}
type TStruct interface {
Write(p TProtocol) error
Read(p TProtocol) error
}
func NewTSerializer() *TSerializer {
transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
return &TSerializer{
transport,
protocol}
}
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(ctx); err != nil {
return
}
return t.Transport.String(), nil
}
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(ctx); err != nil {
return
}
b = append(b, t.Transport.Bytes()...)
return
}

View file

@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"errors"
"fmt"
"testing"
)
type ProtocolFactory interface {
GetProtocol(t TTransport) TProtocol
}
func compareStructs(m, m1 MyTestStruct) (bool, error) {
switch {
case m.On != m1.On:
return false, errors.New("Boolean not equal")
case m.B != m1.B:
return false, errors.New("Byte not equal")
case m.Int16 != m1.Int16:
return false, errors.New("Int16 not equal")
case m.Int32 != m1.Int32:
return false, errors.New("Int32 not equal")
case m.Int64 != m1.Int64:
return false, errors.New("Int64 not equal")
case m.D != m1.D:
return false, errors.New("Double not equal")
case m.St != m1.St:
return false, errors.New("String not equal")
case len(m.Bin) != len(m1.Bin):
return false, errors.New("Binary size not equal")
case len(m.Bin) == len(m1.Bin):
for i := range m.Bin {
if m.Bin[i] != m1.Bin[i] {
return false, errors.New("Binary not equal")
}
}
case len(m.StringMap) != len(m1.StringMap):
return false, errors.New("StringMap size not equal")
case len(m.StringList) != len(m1.StringList):
return false, errors.New("StringList size not equal")
case len(m.StringSet) != len(m1.StringSet):
return false, errors.New("StringSet size not equal")
case m.E != m1.E:
return false, errors.New("MyTestEnum not equal")
default:
return true, nil
}
return true, nil
}
func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) {
t := NewTSerializer()
t.Protocol = pf.GetProtocol(t.Transport)
var m = MyTestStruct{}
m.On = true
m.B = int8(0)
m.Int16 = 1
m.Int32 = 2
m.Int64 = 3
m.D = 4.1
m.St = "Test"
m.Bin = make([]byte, 10)
m.StringMap = make(map[string]string, 5)
m.StringList = make([]string, 5)
m.StringSet = make(map[string]struct{}, 5)
m.E = 2
s, err := t.WriteString(context.Background(), &m)
if err != nil {
return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
}
t1 := NewTDeserializer()
t1.Protocol = pf.GetProtocol(t1.Transport)
var m1 = MyTestStruct{}
if err = t1.ReadString(&m1, s); err != nil {
return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
}
return compareStructs(m, m1)
}
func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) {
t := NewTSerializer()
t.Protocol = pf.GetProtocol(t.Transport)
var m = MyTestStruct{}
m.On = false
m.B = int8(0)
m.Int16 = 1
m.Int32 = 2
m.Int64 = 3
m.D = 4.1
m.St = "Test"
m.Bin = make([]byte, 10)
m.StringMap = make(map[string]string, 5)
m.StringList = make([]string, 5)
m.StringSet = make(map[string]struct{}, 5)
m.E = 2
s, err := t.WriteString(context.Background(), &m)
if err != nil {
return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
}
t1 := NewTDeserializer()
t1.Protocol = pf.GetProtocol(t1.Transport)
var m1 = MyTestStruct{}
if err = t1.ReadString(&m1, s); err != nil {
return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
}
return compareStructs(m, m1)
}
func TestSerializer(t *testing.T) {
var protocol_factories map[string]ProtocolFactory
protocol_factories = make(map[string]ProtocolFactory)
protocol_factories["Binary"] = NewTBinaryProtocolFactoryDefault()
protocol_factories["Compact"] = NewTCompactProtocolFactory()
//protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design
protocol_factories["JSON"] = NewTJSONProtocolFactory()
var tests map[string]func(*testing.T, ProtocolFactory) (bool, error)
tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error))
tests["Test 1"] = ProtocolTest1
tests["Test 2"] = ProtocolTest2
//tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests
for name, pf := range protocol_factories {
for test, f := range tests {
if s, err := f(t, pf); !s || err != nil {
t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err)
}
}
}
}

View file

@ -0,0 +1,633 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
/* THE FOLLOWING THRIFT FILE WAS USED TO CREATE THIS
enum MyTestEnum {
FIRST = 1,
SECOND = 2,
THIRD = 3,
FOURTH = 4,
}
struct MyTestStruct {
1: bool on,
2: byte b,
3: i16 int16,
4: i32 int32,
5: i64 int64,
6: double d,
7: string st,
8: binary bin,
9: map<string, string> stringMap,
10: list<string> stringList,
11: set<string> stringSet,
12: MyTestEnum e,
}
*/
import (
"fmt"
)
// (needed to ensure safety because of naive import list construction.)
var _ = ZERO
var _ = fmt.Printf
var GoUnusedProtection__ int
type MyTestEnum int64
const (
MyTestEnum_FIRST MyTestEnum = 1
MyTestEnum_SECOND MyTestEnum = 2
MyTestEnum_THIRD MyTestEnum = 3
MyTestEnum_FOURTH MyTestEnum = 4
)
func (p MyTestEnum) String() string {
switch p {
case MyTestEnum_FIRST:
return "FIRST"
case MyTestEnum_SECOND:
return "SECOND"
case MyTestEnum_THIRD:
return "THIRD"
case MyTestEnum_FOURTH:
return "FOURTH"
}
return "<UNSET>"
}
func MyTestEnumFromString(s string) (MyTestEnum, error) {
switch s {
case "FIRST":
return MyTestEnum_FIRST, nil
case "SECOND":
return MyTestEnum_SECOND, nil
case "THIRD":
return MyTestEnum_THIRD, nil
case "FOURTH":
return MyTestEnum_FOURTH, nil
}
return MyTestEnum(0), fmt.Errorf("not a valid MyTestEnum string")
}
func MyTestEnumPtr(v MyTestEnum) *MyTestEnum { return &v }
type MyTestStruct struct {
On bool `thrift:"on,1" json:"on"`
B int8 `thrift:"b,2" json:"b"`
Int16 int16 `thrift:"int16,3" json:"int16"`
Int32 int32 `thrift:"int32,4" json:"int32"`
Int64 int64 `thrift:"int64,5" json:"int64"`
D float64 `thrift:"d,6" json:"d"`
St string `thrift:"st,7" json:"st"`
Bin []byte `thrift:"bin,8" json:"bin"`
StringMap map[string]string `thrift:"stringMap,9" json:"stringMap"`
StringList []string `thrift:"stringList,10" json:"stringList"`
StringSet map[string]struct{} `thrift:"stringSet,11" json:"stringSet"`
E MyTestEnum `thrift:"e,12" json:"e"`
}
func NewMyTestStruct() *MyTestStruct {
return &MyTestStruct{}
}
func (p *MyTestStruct) GetOn() bool {
return p.On
}
func (p *MyTestStruct) GetB() int8 {
return p.B
}
func (p *MyTestStruct) GetInt16() int16 {
return p.Int16
}
func (p *MyTestStruct) GetInt32() int32 {
return p.Int32
}
func (p *MyTestStruct) GetInt64() int64 {
return p.Int64
}
func (p *MyTestStruct) GetD() float64 {
return p.D
}
func (p *MyTestStruct) GetSt() string {
return p.St
}
func (p *MyTestStruct) GetBin() []byte {
return p.Bin
}
func (p *MyTestStruct) GetStringMap() map[string]string {
return p.StringMap
}
func (p *MyTestStruct) GetStringList() []string {
return p.StringList
}
func (p *MyTestStruct) GetStringSet() map[string]struct{} {
return p.StringSet
}
func (p *MyTestStruct) GetE() MyTestEnum {
return p.E
}
func (p *MyTestStruct) Read(iprot TProtocol) error {
if _, err := iprot.ReadStructBegin(); err != nil {
return PrependError(fmt.Sprintf("%T read error: ", p), err)
}
for {
_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
if err != nil {
return PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
}
if fieldTypeId == STOP {
break
}
switch fieldId {
case 1:
if err := p.readField1(iprot); err != nil {
return err
}
case 2:
if err := p.readField2(iprot); err != nil {
return err
}
case 3:
if err := p.readField3(iprot); err != nil {
return err
}
case 4:
if err := p.readField4(iprot); err != nil {
return err
}
case 5:
if err := p.readField5(iprot); err != nil {
return err
}
case 6:
if err := p.readField6(iprot); err != nil {
return err
}
case 7:
if err := p.readField7(iprot); err != nil {
return err
}
case 8:
if err := p.readField8(iprot); err != nil {
return err
}
case 9:
if err := p.readField9(iprot); err != nil {
return err
}
case 10:
if err := p.readField10(iprot); err != nil {
return err
}
case 11:
if err := p.readField11(iprot); err != nil {
return err
}
case 12:
if err := p.readField12(iprot); err != nil {
return err
}
default:
if err := iprot.Skip(fieldTypeId); err != nil {
return err
}
}
if err := iprot.ReadFieldEnd(); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
return PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
}
return nil
}
func (p *MyTestStruct) readField1(iprot TProtocol) error {
if v, err := iprot.ReadBool(); err != nil {
return PrependError("error reading field 1: ", err)
} else {
p.On = v
}
return nil
}
func (p *MyTestStruct) readField2(iprot TProtocol) error {
if v, err := iprot.ReadByte(); err != nil {
return PrependError("error reading field 2: ", err)
} else {
temp := int8(v)
p.B = temp
}
return nil
}
func (p *MyTestStruct) readField3(iprot TProtocol) error {
if v, err := iprot.ReadI16(); err != nil {
return PrependError("error reading field 3: ", err)
} else {
p.Int16 = v
}
return nil
}
func (p *MyTestStruct) readField4(iprot TProtocol) error {
if v, err := iprot.ReadI32(); err != nil {
return PrependError("error reading field 4: ", err)
} else {
p.Int32 = v
}
return nil
}
func (p *MyTestStruct) readField5(iprot TProtocol) error {
if v, err := iprot.ReadI64(); err != nil {
return PrependError("error reading field 5: ", err)
} else {
p.Int64 = v
}
return nil
}
func (p *MyTestStruct) readField6(iprot TProtocol) error {
if v, err := iprot.ReadDouble(); err != nil {
return PrependError("error reading field 6: ", err)
} else {
p.D = v
}
return nil
}
func (p *MyTestStruct) readField7(iprot TProtocol) error {
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 7: ", err)
} else {
p.St = v
}
return nil
}
func (p *MyTestStruct) readField8(iprot TProtocol) error {
if v, err := iprot.ReadBinary(); err != nil {
return PrependError("error reading field 8: ", err)
} else {
p.Bin = v
}
return nil
}
func (p *MyTestStruct) readField9(iprot TProtocol) error {
_, _, size, err := iprot.ReadMapBegin()
if err != nil {
return PrependError("error reading map begin: ", err)
}
tMap := make(map[string]string, size)
p.StringMap = tMap
for i := 0; i < size; i++ {
var _key0 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_key0 = v
}
var _val1 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_val1 = v
}
p.StringMap[_key0] = _val1
}
if err := iprot.ReadMapEnd(); err != nil {
return PrependError("error reading map end: ", err)
}
return nil
}
func (p *MyTestStruct) readField10(iprot TProtocol) error {
_, size, err := iprot.ReadListBegin()
if err != nil {
return PrependError("error reading list begin: ", err)
}
tSlice := make([]string, 0, size)
p.StringList = tSlice
for i := 0; i < size; i++ {
var _elem2 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_elem2 = v
}
p.StringList = append(p.StringList, _elem2)
}
if err := iprot.ReadListEnd(); err != nil {
return PrependError("error reading list end: ", err)
}
return nil
}
func (p *MyTestStruct) readField11(iprot TProtocol) error {
_, size, err := iprot.ReadSetBegin()
if err != nil {
return PrependError("error reading set begin: ", err)
}
tSet := make(map[string]struct{}, size)
p.StringSet = tSet
for i := 0; i < size; i++ {
var _elem3 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_elem3 = v
}
p.StringSet[_elem3] = struct{}{}
}
if err := iprot.ReadSetEnd(); err != nil {
return PrependError("error reading set end: ", err)
}
return nil
}
func (p *MyTestStruct) readField12(iprot TProtocol) error {
if v, err := iprot.ReadI32(); err != nil {
return PrependError("error reading field 12: ", err)
} else {
temp := MyTestEnum(v)
p.E = temp
}
return nil
}
func (p *MyTestStruct) Write(oprot TProtocol) error {
if err := oprot.WriteStructBegin("MyTestStruct"); err != nil {
return PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
}
if err := p.writeField1(oprot); err != nil {
return err
}
if err := p.writeField2(oprot); err != nil {
return err
}
if err := p.writeField3(oprot); err != nil {
return err
}
if err := p.writeField4(oprot); err != nil {
return err
}
if err := p.writeField5(oprot); err != nil {
return err
}
if err := p.writeField6(oprot); err != nil {
return err
}
if err := p.writeField7(oprot); err != nil {
return err
}
if err := p.writeField8(oprot); err != nil {
return err
}
if err := p.writeField9(oprot); err != nil {
return err
}
if err := p.writeField10(oprot); err != nil {
return err
}
if err := p.writeField11(oprot); err != nil {
return err
}
if err := p.writeField12(oprot); err != nil {
return err
}
if err := oprot.WriteFieldStop(); err != nil {
return PrependError("write field stop error: ", err)
}
if err := oprot.WriteStructEnd(); err != nil {
return PrependError("write struct stop error: ", err)
}
return nil
}
func (p *MyTestStruct) writeField1(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 1:on: ", p), err)
}
if err := oprot.WriteBool(bool(p.On)); err != nil {
return PrependError(fmt.Sprintf("%T.on (1) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 1:on: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField2(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 2:b: ", p), err)
}
if err := oprot.WriteByte(int8(p.B)); err != nil {
return PrependError(fmt.Sprintf("%T.b (2) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 2:b: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField3(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 3:int16: ", p), err)
}
if err := oprot.WriteI16(int16(p.Int16)); err != nil {
return PrependError(fmt.Sprintf("%T.int16 (3) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 3:int16: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField4(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 4:int32: ", p), err)
}
if err := oprot.WriteI32(int32(p.Int32)); err != nil {
return PrependError(fmt.Sprintf("%T.int32 (4) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 4:int32: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField5(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 5:int64: ", p), err)
}
if err := oprot.WriteI64(int64(p.Int64)); err != nil {
return PrependError(fmt.Sprintf("%T.int64 (5) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 5:int64: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField6(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 6:d: ", p), err)
}
if err := oprot.WriteDouble(float64(p.D)); err != nil {
return PrependError(fmt.Sprintf("%T.d (6) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 6:d: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField7(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 7:st: ", p), err)
}
if err := oprot.WriteString(string(p.St)); err != nil {
return PrependError(fmt.Sprintf("%T.st (7) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 7:st: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField8(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 8:bin: ", p), err)
}
if err := oprot.WriteBinary(p.Bin); err != nil {
return PrependError(fmt.Sprintf("%T.bin (8) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 8:bin: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField9(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 9:stringMap: ", p), err)
}
if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil {
return PrependError("error writing map begin: ", err)
}
for k, v := range p.StringMap {
if err := oprot.WriteString(string(k)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
if err := oprot.WriteString(string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
if err := oprot.WriteMapEnd(); err != nil {
return PrependError("error writing map end: ", err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 9:stringMap: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField10(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 10:stringList: ", p), err)
}
if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil {
return PrependError("error writing list begin: ", err)
}
for _, v := range p.StringList {
if err := oprot.WriteString(string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
if err := oprot.WriteListEnd(); err != nil {
return PrependError("error writing list end: ", err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 10:stringList: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 11:stringSet: ", p), err)
}
if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil {
return PrependError("error writing set begin: ", err)
}
for v := range p.StringSet {
if err := oprot.WriteString(string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
if err := oprot.WriteSetEnd(); err != nil {
return PrependError("error writing set end: ", err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 11:stringSet: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField12(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("e", I32, 12); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 12:e: ", p), err)
}
if err := oprot.WriteI32(int32(p.E)); err != nil {
return PrependError(fmt.Sprintf("%T.e (12) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 12:e: ", p), err)
}
return err
}
func (p *MyTestStruct) String() string {
if p == nil {
return "<nil>"
}
return fmt.Sprintf("MyTestStruct(%+v)", *p)
}

View file

@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
type TServer interface {
ProcessorFactory() TProcessorFactory
ServerTransport() TServerTransport
InputTransportFactory() TTransportFactory
OutputTransportFactory() TTransportFactory
InputProtocolFactory() TProtocolFactory
OutputProtocolFactory() TProtocolFactory
// Starts the server
Serve() error
// Stops the server. This is optional on a per-implementation basis. Not
// all servers are required to be cleanly stoppable.
Stop() error
}

View file

@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"net"
"sync"
"time"
)
type TServerSocket struct {
listener net.Listener
addr net.Addr
clientTimeout time.Duration
// Protects the interrupted value to make it thread safe.
mu sync.RWMutex
interrupted bool
}
func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
return NewTServerSocketTimeout(listenAddr, 0)
}
func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*TServerSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err
}
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil
}
// Creates a TServerSocket from a net.Addr
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
}
func (p *TServerSocket) Listen() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return nil
}
l, err := net.Listen(p.addr.Network(), p.addr.String())
if err != nil {
return err
}
p.listener = l
return nil
}
func (p *TServerSocket) Accept() (TTransport, error) {
p.mu.RLock()
interrupted := p.interrupted
p.mu.RUnlock()
if interrupted {
return nil, errTransportInterrupted
}
p.mu.Lock()
listener := p.listener
p.mu.Unlock()
if listener == nil {
return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
}
conn, err := listener.Accept()
if err != nil {
return nil, NewTTransportExceptionFromError(err)
}
return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil
}
// Checks whether the socket is listening.
func (p *TServerSocket) IsListening() bool {
return p.listener != nil
}
// Connects the socket, creating a new socket object if necessary.
func (p *TServerSocket) Open() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil {
return err
} else {
p.listener = l
}
return nil
}
func (p *TServerSocket) Addr() net.Addr {
if p.listener != nil {
return p.listener.Addr()
}
return p.addr
}
func (p *TServerSocket) Close() error {
var err error
p.mu.Lock()
if p.IsListening() {
err = p.listener.Close()
p.listener = nil
}
p.mu.Unlock()
return err
}
func (p *TServerSocket) Interrupt() error {
p.mu.Lock()
p.interrupted = true
p.mu.Unlock()
p.Close()
return nil
}

View file

@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"fmt"
"testing"
)
func TestSocketIsntListeningAfterInterrupt(t *testing.T) {
host := "127.0.0.1"
port := 9090
addr := fmt.Sprintf("%s:%d", host, port)
socket := CreateServerSocket(t, addr)
socket.Listen()
socket.Interrupt()
newSocket := CreateServerSocket(t, addr)
err := newSocket.Listen()
defer newSocket.Interrupt()
if err != nil {
t.Fatalf("Failed to rebinds: %s", err)
}
}
func TestSocketConcurrency(t *testing.T) {
host := "127.0.0.1"
port := 9090
addr := fmt.Sprintf("%s:%d", host, port)
socket := CreateServerSocket(t, addr)
go func() { socket.Listen() }()
go func() { socket.Interrupt() }()
}
func CreateServerSocket(t *testing.T, addr string) *TServerSocket {
socket, err := NewTServerSocket(addr)
if err != nil {
t.Fatalf("Failed to create server socket: %s", err)
}
return socket
}

Some files were not shown because too many files have changed in this diff Show more