Compare commits

..

109 commits

Author SHA1 Message Date
lawwong1
d2fd7b9ba9
merge retry mechanism change from gorealis v1 to gorealis v2 (#21) 2023-01-26 13:36:40 -08:00
lawwong1
8db625730f
Support Australis API to get aurora master nodes and mesos master nodes (#20) 2022-08-24 08:51:12 -07:00
Tan N. Le
e33a2d99d8
release 2.28.0 (#19) 2022-08-02 09:55:36 -07:00
Tan N. Le
4258634ccf
Capacity report (#18)
- pull capacity report via /offers endpoint.
- calculate how many tasks (with resource and constraints) can be fit in the cluster.
examples of using the above 2 features are in aurora-scheduler/australis#33
2022-07-28 19:27:53 -07:00
Tan N. Le
5d0998647a
default policy for slaDrainHosts (#17) 2021-11-01 18:15:51 -07:00
Renán I. Del Valle
907430768c
Misc. fixes for tests (#16)
* Bumping up CI to go1.17 and enabling CI for PRs.

* Adding go.sum now that issues seem to have gone away.

* Bump up aurora to 0.25.0 and mesos to 1.9.0

* Fixing Mac tests. Adding extra time for killing thermos jobs.

* Reduce the thermos overhead for unit tests

Co-authored-by: lenhattan86 <lenhattan86@users.noreply.github.com>
2021-10-25 12:39:13 -07:00
lenhattan86
fe664178ce
Add tier & production in task config (#14) 2021-10-15 12:18:26 -07:00
lenhattan86
a75b691d72
Merge pull request #15 from lenhattan86/fix_unit_test
Fix unit test for GetJobSummary
2021-10-07 13:14:25 -07:00
lenhattan86
306603795b fix unit test error for GetJobSummary 2021-10-06 22:37:54 -07:00
lenhattan86
045a4869a5
Merge branch 'aurora-scheduler:master' into master 2021-10-06 14:01:03 -07:00
lenhattan86
425faf28b8
Adds priority for aurora-scheduler (#13)
Adds priority for task config
2021-09-16 16:29:25 -07:00
lenhattan86
2d81147aaa Merge branch 'aurora-scheduler:master' into master 2021-06-01 21:35:42 -07:00
Renán I. Del Valle
983bf44b9f
Update thrift to 0.14.0 (#9)
Generated thrift stubs using 0.14.0 compiler version.
Script now tells user to use version 0.14.0 of thrift compiler.
2021-03-01 16:52:25 -08:00
Renán I. Del Valle
d0be43b8ac
Dropping support for dep (#10)
Dep files are no longer necessary.
2021-03-01 15:36:28 -08:00
lenhattan86
b1661698c2
GetJobSummary API (#8)
* Adds GetJobSummary API
2021-01-12 16:18:09 -08:00
lenhattan86
364ee93202
Merge pull request #1 from aurora-scheduler/master
pull from upstream
2021-01-12 15:09:36 -08:00
Renan DelValle
755f99fb76
Fixes style issue with jobupdate file. 2020-11-16 21:51:02 -08:00
Renán I. Del Valle
caf1444250
Removes variables from github actions
Github Actions deprecated support for using env files without previously setting them. Adjusting CI scripts accordingly.
2020-11-16 21:45:00 -08:00
lenhattan86
c3dbeba2bd
Adds ability to fetch Mesos Master leader (#7)
* Adds ability to fetch Mesos Master leader from Zookeeper
2020-11-15 16:44:21 -08:00
Renan DelValle
6c639362c8
Bumping up go tests time to 30m. 2020-07-27 21:44:35 -07:00
Renán I. Del Valle
4cf60775f5
Bumping up thrift go library version to v0.13.2 (#6)
Thrift v0.13.2 is a forked version of v0.13.0 with a patch to not panic when trying to write to a closed buffer. Instead we get an error back and we can handle it appropriately.
2020-05-26 20:32:51 -07:00
Renán I. Del Valle
30f804bc53
Update using-the-sample-client.md
Fixing typo on doc.
2020-05-20 18:26:33 -07:00
Renán I. Del Valle
e5d63579e8
Update using-the-sample-client.md
Updating instructions for using the sample client.
2020-05-20 18:26:05 -07:00
Renán I. Del Valle
34a950306d
Update developing.md
Updating documentation for developing gorealis
2020-05-20 18:21:47 -07:00
Renan DelValle
851f9686b6
Bumping up version. 2020-05-07 11:35:22 -07:00
Renan DelValle
72c04220fe
Removing vendoring folder now that packages are cached in goproxy. 2020-05-07 11:33:47 -07:00
Renan DelValle
96384e6fdc
Renaming function in job update to reduce ambiguity about which underlying object it is modifiying. Adding support for specifying ranges in updates as well as SLA Awareness. 2020-05-07 11:32:36 -07:00
Renán I. Del Valle
ed81bcb28d
Increasing test time out to 20 mins
Increasing test timeout to 20mins for CI
2020-05-05 23:02:44 -07:00
Renán I. Del Valle
69ced895e2
Upgrade to Aurora 0.22.0 (#5)
* Upgrading to Thrift 0.13.1. This version is a fork of 0.13.0 with a patch on top of it to fix an issue where trying a realis call after the connection has been closed results in a panic.

* Upgrading compose set up to Mesos 1.6.2 and Aurora 0.22.0.

* Adding support for using different update strategies.

* Adding a monitor that is friendly with auto pause.

* Adding tests for new update strategies.
2020-05-05 20:55:25 -07:00
Renán I. Del Valle
1d8afcd329
Adding github actions CI (#4)
* CI will run on pushes to the main branch.
2020-02-26 11:37:39 -08:00
Renán I. Del Valle
406640c7a9
Add a few items to gitignore. Change few missed dependencies to point to aurora-scheduler repository. (#3) 2020-02-19 12:01:02 -08:00
Renán I. Del Valle
02710e5434
Moving repository to aurora-scheduler organization. (#2)
gorealis v2 will now live in the aurora-scheduler organization
2020-02-19 11:40:40 -08:00
Renán I. Del Valle
3a6a93f946
Changing module address to be under aurora-scheduler (#1)
The v2 version of gorealis will now be housed under the aurora-scheduler organization.
2020-02-18 18:07:17 -08:00
Renán I. Del Valle
fc983fa096
Avoid panics using a forked Thrift version while we wait for Thrift to release 0.14.0 (#119)
* Changing README to point to the incarnation of the aurora scheduler project.

* Pointing to a forked patch version of thrift using mod while we wait for the fix that will land in 0.14.0.
2020-02-18 14:18:45 -08:00
Renan I. Del Valle
7b0c75450b
Removing go sum
Since go has launched a checksum database, it is no longer necessary to store go.sum file.
https://blog.golang.org/module-mirror-launch
2020-02-05 13:41:02 -08:00
Renan DelValle
235f854087 Changing calls on functions that use JobUpdateKey to reflect change made for memory safety. 2019-09-25 17:20:30 -07:00
Renan DelValle
4fc4953ec4 Change JobUpdateKey pointers to be literals, then we deep copy the JobKey pointer to a new JobKey in order to avoid side effects. 2019-09-25 17:20:30 -07:00
Renan DelValle
119d1c429b Moving client configuration options from realis to its own sepearate file to make code more digestible. 2019-09-25 17:20:30 -07:00
Renan DelValle
a8a7cf779f Splitting realis into regular API and admin API files. 2019-09-25 17:20:30 -07:00
Renan DelValle
98f2cab4a2 Renamed Aurora address validator to be less redudnant. Added tests cribbed from version 1. 2019-09-25 17:20:30 -07:00
Renan DelValle
09628391cc Cleaning up error messages and some formatting. 2019-09-25 17:20:30 -07:00
Renan DelValle
f72fdacfb0 Changing the names of the protocol constants to be more descriptive. 2019-09-25 17:20:30 -07:00
Renan DelValle
55cf9bcb70 Adding more fine grained controls to retry mechanism. Retry mechanism may now be configured to not retry if an error is hit or to specifically stop retrying if a timeout error is encountered. 2019-09-25 17:20:30 -07:00
Renan DelValle
fe4a0dc06e Minor error message clarification 2019-09-25 17:20:30 -07:00
Renan DelValle
d67b8ca1d7 Removing uncessary functions which previously handled initializing thrift protocol. Changed how which protocol is chosen based upon configuration. 2019-09-25 17:20:30 -07:00
Renan DelValle
ecd59f7a8d Removing unnecessary cookie jar from thrift protocol initialization. 2019-09-25 17:20:30 -07:00
Renan DelValle
5d75dcc15e Adding MonitorJobUpdateQuery which serves as the basis for other monitors. 2019-09-25 17:20:30 -07:00
Renan DelValle
9a70711537 Making JobUpdate synchronous. MonitorJobUpdateStatus creates a local copy of job key in order to guard against side effects cuased by mutations to the JobKey being performed externally. 2019-09-25 17:20:30 -07:00
Renan DelValle
203f178d68 Changing error messages to be lower case in realis API. 2019-09-25 17:20:30 -07:00
Renan DelValle
9584266b71 Changing MonitorJobUpdate to use MOnitorJobUpdateStatus under the hood. 2019-09-25 17:20:30 -07:00
Renan DelValle
6f20f5b62f Adding JobUpdateStatus monitor as well as renaming all monitor functions to be Monitor + <subject> 2019-09-25 17:20:30 -07:00
Renan DelValle
04471c6918 Adding trace logging. 2019-09-25 17:20:30 -07:00
Renan DelValle
dbad078d95 Adding missing indirection for adding GPU requirements to task. 2019-09-25 17:20:30 -07:00
Renan DelValle
b9db36520c Adding realis test file which currently tests the get certs function. 2019-09-25 17:20:30 -07:00
Renan DelValle
2c795debfd Updating runtestMac to be the same as in the gorealis v1 repo. 2019-09-25 17:20:30 -07:00
Robert Allen
c553f67d4e Adding support for PartitionPolicy. 2019-09-25 17:20:30 -07:00
Renan DelValle
461b23400c
V2.0 thrift repository migration and cleanup (#98)
* Remove unnecessary files from the thrift repository that come along with the go library.

* Updating thrift generated code to be 0.12.0 final generated code.

* Remove git.apache.org dependency in vendor folder.

* Migrating from git.apache.org/thrift.git to github.com/apache/thrift

* Upgrading dep (although it will not work now that imports are using mod format, it allows for users to easily fix this with a replacement of the import path).

* Upgrading mod dependencies for Thrift to point to github.com location of the repository.

* Bug fix for Thermos Payload generation relating to the GPU being set.
2019-02-19 16:40:41 -08:00
Renan DelValle
9b3593e9d9
Fixing GPU resource to only be added if specified since Aurora scheduler by default will reject tasks containing GPU. 2019-01-08 17:47:13 -08:00
Renan DelValle
8d67d8c2f3
Releasing 2.0.1. 2019-01-08 15:59:56 -08:00
Renan DelValle
e13349db26
Initial support for Thermos and GPU resources. 2019-01-07 14:39:47 -08:00
Renan DelValle
afcdaa84b8
Initial support for generating Thermos data objects. 2018-12-28 11:46:14 -08:00
Renan DelValle
51597ecb32
Changing paths to refer to gorealis v2 in order for dependencies to be correct. 2018-12-27 10:09:22 -08:00
Renan DelValle
acbe9ad9e5
Upgrading vendor folder dependencies. 2018-12-27 09:58:53 -08:00
Renan DelValle
4a0cbcd770
Updating codecov badge to point to the right placce. 2018-12-26 18:17:05 -08:00
Renan DelValle
b776bd301d
Adding v2 to module. 2018-12-23 12:44:26 -08:00
Renan DelValle
e4e8a1c0b3
Adding a check for 401. This reduces the retries on the end to end test and fails fast when a wrong/unathorized username and password are provided to interact with Aurora. 2018-12-18 17:14:48 -08:00
Renan DelValle
71d41de2e4
Fixing bug for logger which passed everything as an array instead of unrolling the array to the printer. 2018-12-18 16:41:31 -08:00
Renan DelValle
84e8762495
Refactoring URL validation tets to be more terse as suggested by Pinglei. 2018-12-18 12:44:08 -08:00
Renan DelValle
11c71b0463
Upgrading container where MacOS tests run to 1.11. Upgrading tuo thrift 12 for binding generation. 2018-12-18 12:38:58 -08:00
Renan DelValle
8f9a678b7d
Using more golang standard constant naming. 2018-12-18 12:38:25 -08:00
Renan DelValle
fdd94e9bea
Adding a shiro.ini configuration in order to test bad password using compose setup. 2018-12-18 12:37:50 -08:00
Renan DelValle
67b37d5a42
Improving detection of protocol to not accidentally add one protocol in front of the other. 2018-12-17 18:06:40 -08:00
Renan DelValle
56b325ed80
Aurora endpoint may now be explicitly provided with or without protocol and with or without port. 2018-12-17 18:00:20 -08:00
Renan DelValle
ef421f60c3
Adding mod support to gorealis. 2018-12-12 19:06:51 -08:00
Renan DelValle
c4691c7347
Bumping travis CI to go 1.11 2018-12-12 14:34:46 -08:00
Renan DelValle
533591ab89
Ran project through newest goimports. 2018-12-12 14:25:06 -08:00
Renan DelValle
0c00765995
Refactoring tests to reflect API changes. 2018-12-12 14:14:58 -08:00
Renan DelValle
0b43a58b15
Refactoring test to reduce code size. 2018-12-12 14:14:31 -08:00
Renan DelValle
992e52eba2
Changing realis API to use new JobUpdate struct and to use concrete JobKey types. 2018-12-12 14:13:45 -08:00
Renan DelValle
0c32a7e683
Refactored client.go example to match new api. Fixed typeo in jobUpdate function JobUpdateFromConfig. 2018-12-12 14:12:31 -08:00
Renan DelValle
e1906542a6
Allowing task and job to return job keys. Job keys are now passed around as concrete types (not pointers) due to the possiblity of side effects being cause if pointers to job keys inside of another struct are passed around. Cloning now uses the TaskFromThrift method to do a deep copy of an AuroraTask. 2018-12-12 14:01:26 -08:00
Renan DelValle
005980fc44
Refactor of update job code to use an AuroraTask underneath it and forward the necessary pointer receivers down to the AuroraTask. Code and tests for doing a deep copy of AuroraTask have been included. 2018-12-11 17:45:49 -08:00
Renan DelValle
98b4061513
Renamed Task to AuroraTask to avoid confusion with Mesos tasks. Added constants to access certain resources to avoid confusion and to ensure compile time safety. 2018-12-11 16:51:50 -08:00
Renan DelValle
e00e0a0492
Changing all containers to use pointer receiver since they are sharing around a single pointer inside the struct, I want to convey to users that the data is all being shared by a pointer. 2018-12-11 16:49:37 -08:00
Renan DelValle
5836ede37b
Splitting off Aurora task from Aurora Job since Update mechanism only needs task. 2018-12-10 18:57:16 -08:00
Renan DelValle
b0c25e9013
Refactor updatejob to JobUpdate to be more in line with Aurora terminology. 2018-12-10 18:13:28 -08:00
Renan DelValle
76300782ba
Renaming RealisClient to Client to avoid stuttering. Moving monitors under Client. Making configuration object private. Deleted legacy code to generate configuration object. 2018-12-08 08:57:15 -08:00
Renan DelValle
c1be2fe62b
Monitors are now all pointer receivers for RealisClient. 2018-12-07 16:08:49 -08:00
Renan DelValle
133938b307
Adding Tier. 2018-12-07 16:01:23 -08:00
Renan DelValle
c071e5ca62
Updating json client to use new API. 2018-12-04 15:19:08 -08:00
Renan DelValle
c00b83b14c
Making changes to sample client to match the refactoring done to main library. 2018-12-04 15:17:22 -08:00
Renan DelValle
47d955d4a4
Adding Gopkg.lock to ignore from diff file for github. 2018-11-29 17:47:12 -08:00
Renan DelValle
99b03c1254
Remove vendored folder and gen-go from Github diffs. 2018-11-29 17:45:19 -08:00
Renan DelValle
7967270b3b
Refactoring NewJob to use struct literals for clarity. 2018-11-29 17:06:45 -08:00
Renan DelValle
54378b2d8a
Changing the signature for some API. Specifically, result objects that hold a single variable are now returning that variable instead of a result object. Tests have been refcatored to use new v2 API. All tests are currently passing. 2018-11-28 20:13:49 -08:00
Renan DelValle
59e3a7065e
Refactoring code to be compatible with Thrift 0.12.0 generated code. Tests are still not refactored. 2018-11-27 18:45:10 -08:00
Renan DelValle
cec9c001fb
Upgrading dependencies 2018-11-27 18:44:33 -08:00
Renan DelValle
366599fb80
Regenerating Thrift bindings with Thrift 0.12.0 2018-11-27 18:05:33 -08:00
Renan DelValle
356978cb42
Upgrading dependency to Thrift 0.12.0 2018-11-27 18:03:50 -08:00
Renan DelValle
3e4590dcc0
Changing monitors to use time.Duration to be more explicit in code and to have tighter control. 2018-11-22 14:03:51 -08:00
Renan DelValle
b6effe66b7
Moving cluster struct factory from realis to clusters. 2018-11-22 12:23:20 -08:00
Renan DelValle
848b5f7971
Eliminating deprecated response code check since retry call does this automatically. 2018-11-22 12:23:20 -08:00
Renan DelValle
d747a48626
Simplifying API. Many API calls have gone from a tuple of two returns to a single return. 2018-11-22 12:23:18 -08:00
Renan DelValle
573e45a59c
Simplifying code to use bare structs. 2018-11-22 12:22:26 -08:00
Renan DelValle
8a9a97c150
Removing unnecessary interface from Aurora Job. 2018-11-22 12:22:26 -08:00
Renan DelValle
1146736c2b
Refactoring variable names and variable types to saner versions. 2018-11-22 12:22:25 -08:00
Renan DelValle
c65a47f6e2
Changing Certspath to CertsPath 2018-11-22 12:22:25 -08:00
Renan DelValle
4471c62659
Removing retries as an option since it's a dup of Backoff. 2018-11-22 12:22:25 -08:00
Renan DelValle
a23bd1b2cc
Shedding interface because there is no good reason to have it. 2018-11-22 12:22:22 -08:00
228 changed files with 15047 additions and 48967 deletions

View file

@ -0,0 +1,5 @@
[users]
aurora = secret, admin
[roles]
admin = *

View file

@ -1 +1 @@
0.22.0
0.26.0

32
.github/workflows/main.yml vendored Normal file
View file

@ -0,0 +1,32 @@
name: CI
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
build:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v2
- name: Setup Go for use with actions
uses: actions/setup-go@v2
with:
go-version: 1.17
- name: Install goimports
run: go get golang.org/x/tools/cmd/goimports
- name: Set env with list of directories in repo containin go code
run: echo GO_USR_DIRS=$(go list -f {{.Dir}} ./... | grep -E -v "/gen-go/|/vendor/") >> $GITHUB_ENV
- name: Run goimports check
run: test -z "`for d in $GO_USR_DIRS; do goimports -d $d/*.go | tee /dev/stderr; done`"
- name: Create aurora/mesos docker cluster
run: docker-compose up -d
- name: Run tests
run: go test -timeout 35m -race -coverprofile=coverage.txt -covermode=atomic -v github.com/aurora-scheduler/gorealis/v2

4
.gitignore vendored
View file

@ -37,3 +37,7 @@ _testmain.go
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Example client build
examples/client
examples/jsonClient

View file

@ -1,71 +0,0 @@
# This file contains all available configuration options
# with their default values.
# options for analysis running
run:
# default concurrency is a available CPU number
concurrency: 4
# timeout for analysis, e.g. 30s, 5m, default is 1m
deadline: 1m
# exit code when at least one issue was found, default is 1
issues-exit-code: 1
# include test files or not, default is true
tests: true
skip-dirs:
- gen-go/
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
format: colored-line-number
# print lines of code with issue, default is true
print-issued-lines: true
# print linter name in the end of issue text, default is true
print-linter-name: true
# all available settings of specific linters
linters-settings:
errcheck:
# report about not checking of errors in type assetions: `a := b.(MyStruct)`;
# default is false: such cases aren't reported by default.
check-type-assertions: true
# report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`;
# default is false: such cases aren't reported by default.
check-blank: true
govet:
# report about shadowed variables
check-shadowing: true
goconst:
# minimal length of string constant, 3 by default
min-len: 3
# minimal occurrences count to trigger, 3 by default
min-occurrences: 2
misspell:
# Correct spellings using locale preferences for US or UK.
# Default is to use a neutral variety of English.
# Setting locale to US will correct the British spelling of 'colour' to 'color'.
locale: US
lll:
# max line length, lines longer will be reported. Default is 120.
# '\t' is counted as 1 character by default, and can be changed with the tab-width option
line-length: 120
# tab width in spaces. Default to 1.
tab-width: 4
linters:
enable:
- govet
- goimports
- golint
- lll
- goconst
enable-all: false
fast: false

View file

@ -1,16 +1,9 @@
sudo: required
dist: xenial
language: go
branches:
only:
- master
- master-v2.0
- future
go:
- "1.10.x"
- "1.11.x"
env:
global:
@ -27,7 +20,7 @@ install:
- docker-compose up -d
script:
- go test -race -coverprofile=coverage.txt -covermode=atomic -v github.com/paypal/gorealis
- go test -race -coverprofile=coverage.txt -covermode=atomic -v github.com/aurora-scheduler/gorealis
after_success:
- bash <(curl -s https://codecov.io/bash)

View file

@ -1,25 +0,0 @@
1.22.0 (unreleased)
* CreateService and StartJobUpdate do not continue retrying if a timeout has been encountered
by the HTTP client. Instead they now return an error that conforms to the Timedout interface.
Users can check for a Timedout error by using `realis.IsTimeout(err)`.
* New API function VariableBatchStep has been added which returns the current batch at which
a Variable Batch Update configured Update is currently in.
* Added new PauseUpdateMonitor which monitors an update until it is an `ROLL_FORWARD_PAUSED` state.
* Added variableBatchStep command to sample client to be used for testing new VariableBatchStep api.
* JobUpdateStatus has changed function signature from:
`JobUpdateStatus(updateKey aurora.JobUpdateKey, desiredStatuses map[aurora.JobUpdateStatus]bool, interval, timeout time.Duration) (aurora.JobUpdateStatus, error)`
to
`JobUpdateStatus(updateKey aurora.JobUpdateKey, desiredStatuses []aurora.JobUpdateStatus, interval, timeout time.Duration) (aurora.JobUpdateStatus, error)`
* Added TerminalUpdateStates function which returns an slice containing all UpdateStates which are considered terminal states.
1.21.0
* Version numbering change. Future versions will be labled X.Y.Z where X is the major version, Y is the Aurora version the library has been tested against (e.g. 21 -> 0.21.0), and X is the minor revision.
* Moved to Thrift 0.12.0 code generator and go library.
* `aurora.ACTIVE_STATES`, `aurora.SLAVE_ASSIGNED_STATES`, `aurora.LIVE_STATES`, `aurora.TERMINAL_STATES`, `aurora.ACTIVE_JOB_UPDATE_STATES`, `aurora.AWAITNG_PULSE_JOB_UPDATE_STATES` are all now generated as a slices.
* Please use `realis.ActiveStates`, `realis.SlaveAssignedStates`,`realis.LiveStates`, `realis.TerminalStates`, `realis.ActiveJobUpdateStates`, `realis.AwaitingPulseJobUpdateStates` in their places when map representations are needed.
* `GetInstanceIds(key *aurora.JobKey, states map[aurora.ScheduleStatus]bool) (map[int32]bool, error)` has changed signature to ` GetInstanceIds(key *aurora.JobKey, states []aurora.ScheduleStatus) ([]int32, error)`
* Adding support for GPU as resource.
* Changing compose environment to Aurora snapshot in order to support staggered update.
* Adding staggered updates API.

64
Gopkg.lock generated
View file

@ -1,64 +0,0 @@
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]]
digest = "1:89696c38cec777120b8b1bb5e2d363d655cf2e1e7d8c851919aaa0fd576d9b86"
name = "github.com/apache/thrift"
packages = ["lib/go/thrift"]
pruneopts = ""
revision = "384647d290e2e4a55a14b1b7ef1b7e66293a2c33"
version = "v0.12.0"
[[projects]]
digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b"
name = "github.com/davecgh/go-spew"
packages = ["spew"]
pruneopts = ""
revision = "346938d642f2ec3594ed81d874461961cd0faa76"
version = "v1.1.0"
[[projects]]
digest = "1:df48fb76fb2a40edea0c9b3d960bc95e326660d82ff1114e1f88001f7a236b40"
name = "github.com/pkg/errors"
packages = ["."]
pruneopts = ""
revision = "e881fd58d78e04cf6d0de1217f8707c8cc2249bc"
[[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib"
packages = ["difflib"]
pruneopts = ""
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
digest = "1:78bea5e26e82826dacc5fd64a1013a6711b7075ec8072819b89e6ad76cb8196d"
name = "github.com/samuel/go-zookeeper"
packages = ["zk"]
pruneopts = ""
revision = "471cd4e61d7a78ece1791fa5faa0345dc8c7d5a5"
[[projects]]
digest = "1:381bcbeb112a51493d9d998bbba207a529c73dbb49b3fd789e48c63fac1f192c"
name = "github.com/stretchr/testify"
packages = [
"assert",
"require",
]
pruneopts = ""
revision = "ffdc059bfe9ce6a4e144ba849dbedead332c6053"
version = "v1.3.0"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
input-imports = [
"github.com/apache/thrift/lib/go/thrift",
"github.com/pkg/errors",
"github.com/samuel/go-zookeeper/zk",
"github.com/stretchr/testify/assert",
"github.com/stretchr/testify/require",
]
solver-name = "gps-cdcl"
solver-version = 1

View file

@ -1,16 +0,0 @@
[[constraint]]
name = "github.com/apache/thrift"
version = "0.12.0"
[[constraint]]
name = "github.com/pkg/errors"
revision = "e881fd58d78e04cf6d0de1217f8707c8cc2249bc"
[[constraint]]
name = "github.com/samuel/go-zookeeper"
revision = "471cd4e61d7a78ece1791fa5faa0345dc8c7d5a5"
[[constraint]]
name = "github.com/stretchr/testify"
version = "1.3.0"

View file

@ -1,6 +1,6 @@
# gorealis [![GoDoc](https://godoc.org/github.com/paypal/gorealis?status.svg)](https://godoc.org/github.com/paypal/gorealis) [![Build Status](https://travis-ci.org/paypal/gorealis.svg?branch=master)](https://travis-ci.org/paypal/gorealis) [![codecov](https://codecov.io/gh/paypal/gorealis/branch/master/graph/badge.svg)](https://codecov.io/gh/paypal/gorealis)
# gorealis [![GoDoc](https://godoc.org/github.com/aurora-scheduler/gorealis?status.svg)](https://godoc.org/github.com/aurora-scheduler/gorealis) [![codecov](https://codecov.io/gh/aurora-scheduler/gorealis/branch/master/graph/badge.svg)](https://codecov.io/gh/aurora-scheduler/gorealis/branch/master)
Go library for interacting with [Apache Aurora](https://github.com/apache/aurora).
Go library for interacting with [Aurora Scheduler](https://github.com/aurora-scheduler/aurora).
### Aurora version compatibility
Please see [.auroraversion](./.auroraversion) to see the latest Aurora version against which this
@ -14,7 +14,7 @@ library has been tested.
## Projects using gorealis
* [australis](https://github.com/rdelval/australis)
* [australis](https://github.com/aurora-scheduler/australis)
## Contributions
Contributions are always welcome. Please raise an issue to discuss a contribution before it is made.

View file

@ -21,8 +21,6 @@ import (
"github.com/pkg/errors"
)
// Cluster contains the definition of the clusters.json file used by the default Aurora
// client for configuration
type Cluster struct {
Name string `json:"name"`
AgentRoot string `json:"slave_root"`
@ -30,13 +28,13 @@ type Cluster struct {
ZK string `json:"zk"`
ZKPort int `json:"zk_port"`
SchedZKPath string `json:"scheduler_zk_path"`
MesosZKPath string `json:"mesos_zk_path"`
SchedURI string `json:"scheduler_uri"`
ProxyURL string `json:"proxy_url"`
AuthMechanism string `json:"auth_mechanism"`
}
// LoadClusters loads clusters.json file traditionally located at /etc/aurora/clusters.json
// for use with a gorealis client
// Loads clusters.json file traditionally located at /etc/aurora/clusters.json
func LoadClusters(config string) (map[string]Cluster, error) {
file, err := os.Open(config)
@ -57,3 +55,15 @@ func LoadClusters(config string) (map[string]Cluster, error) {
return m, nil
}
func GetDefaultClusterFromZKUrl(zkURL string) *Cluster {
return &Cluster{
Name: "defaultCluster",
AuthMechanism: "UNAUTHENTICATED",
ZK: zkURL,
SchedZKPath: "/aurora/scheduler",
MesosZKPath: "/mesos",
AgentRunDir: "latest",
AgentRoot: "/var/lib/mesos",
}
}

View file

@ -18,7 +18,7 @@ import (
"fmt"
"testing"
realis "github.com/paypal/gorealis"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/stretchr/testify/assert"
)
@ -32,6 +32,7 @@ func TestLoadClusters(t *testing.T) {
assert.Equal(t, clusters["devcluster"].Name, "devcluster")
assert.Equal(t, clusters["devcluster"].ZK, "192.168.33.7")
assert.Equal(t, clusters["devcluster"].SchedZKPath, "/aurora/scheduler")
assert.Equal(t, clusters["devcluster"].MesosZKPath, "/mesos")
assert.Equal(t, clusters["devcluster"].AuthMechanism, "UNAUTHENTICATED")
assert.Equal(t, clusters["devcluster"].AgentRunDir, "latest")
assert.Equal(t, clusters["devcluster"].AgentRoot, "/var/lib/mesos")

View file

@ -15,44 +15,31 @@
package realis
import (
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Container is an interface that defines a single function needed to create
// an Aurora container type. It exists because the code must support both Mesos
// and Docker containers.
type Container interface {
Build() *aurora.Container
}
// MesosContainer is a Mesos style container that can be used by Aurora Jobs.
type MesosContainer struct {
container *aurora.MesosContainer
}
// DockerContainer is a vanilla Docker style container that can be used by Aurora Jobs.
type DockerContainer struct {
container *aurora.DockerContainer
}
// NewDockerContainer creates a new Aurora compatible Docker container configuration.
func NewDockerContainer() DockerContainer {
return DockerContainer{container: aurora.NewDockerContainer()}
func NewDockerContainer() *DockerContainer {
return &DockerContainer{container: aurora.NewDockerContainer()}
}
// Build creates an Aurora container based upon the configuration provided.
func (c DockerContainer) Build() *aurora.Container {
func (c *DockerContainer) Build() *aurora.Container {
return &aurora.Container{Docker: c.container}
}
// Image adds the name of a Docker image to be used by the Job when running.
func (c DockerContainer) Image(image string) DockerContainer {
func (c *DockerContainer) Image(image string) *DockerContainer {
c.container.Image = image
return c
}
// AddParameter adds a parameter to be passed to Docker when the container is run.
func (c DockerContainer) AddParameter(name, value string) DockerContainer {
func (c *DockerContainer) AddParameter(name, value string) *DockerContainer {
c.container.Parameters = append(c.container.Parameters, &aurora.DockerParameter{
Name: name,
Value: value,
@ -60,18 +47,19 @@ func (c DockerContainer) AddParameter(name, value string) DockerContainer {
return c
}
// NewMesosContainer creates a Mesos style container to be configured and built for use by an Aurora Job.
func NewMesosContainer() MesosContainer {
return MesosContainer{container: aurora.NewMesosContainer()}
type MesosContainer struct {
container *aurora.MesosContainer
}
// Build creates a Mesos style Aurora container configuration to be passed on to the Aurora Job.
func (c MesosContainer) Build() *aurora.Container {
func NewMesosContainer() *MesosContainer {
return &MesosContainer{container: aurora.NewMesosContainer()}
}
func (c *MesosContainer) Build() *aurora.Container {
return &aurora.Container{Mesos: c.container}
}
// DockerImage configures the Mesos container to use a specific Docker image when being run.
func (c MesosContainer) DockerImage(name, tag string) MesosContainer {
func (c *MesosContainer) DockerImage(name, tag string) *MesosContainer {
if c.container.Image == nil {
c.container.Image = aurora.NewImage()
}
@ -80,12 +68,20 @@ func (c MesosContainer) DockerImage(name, tag string) MesosContainer {
return c
}
// AppcImage configures the Mesos container to use an image in the Appc format to run the container.
func (c MesosContainer) AppcImage(name, imageID string) MesosContainer {
func (c *MesosContainer) AppcImage(name, imageId string) *MesosContainer {
if c.container.Image == nil {
c.container.Image = aurora.NewImage()
}
c.container.Image.Appc = &aurora.AppcImage{Name: name, ImageId: imageID}
c.container.Image.Appc = &aurora.AppcImage{Name: name, ImageId: imageId}
return c
}
func (c *MesosContainer) AddVolume(hostPath, containerPath string, mode aurora.Mode) *MesosContainer {
c.container.Volumes = append(c.container.Volumes, &aurora.Volume{
HostPath: hostPath,
ContainerPath: containerPath,
Mode: mode})
return c
}

View file

@ -14,7 +14,7 @@ services:
ipv4_address: 192.168.33.2
master:
image: rdelvalle/mesos-master:1.6.2
image: quay.io/aurorascheduler/mesos-master:1.9.0
restart: on-failure
ports:
- "5050:5050"
@ -32,7 +32,7 @@ services:
- zk
agent-one:
image: rdelvalle/mesos-agent:1.6.2
image: quay.io/aurorascheduler/mesos-agent:1.9.0
pid: host
restart: on-failure
ports:
@ -41,10 +41,11 @@ services:
MESOS_MASTER: zk://192.168.33.2:2181/mesos
MESOS_CONTAINERIZERS: docker,mesos
MESOS_PORT: 5051
MESOS_HOSTNAME: localhost
MESOS_HOSTNAME: agent-one
MESOS_RESOURCES: ports(*):[11000-11999]
MESOS_SYSTEMD_ENABLE_SUPPORT: 'false'
MESOS_WORK_DIR: /tmp/mesos
MESOS_ATTRIBUTES: 'host:agent-one;rack:1;zone:west'
networks:
aurora_cluster:
ipv4_address: 192.168.33.4
@ -56,31 +57,57 @@ services:
- zk
agent-two:
image: rdelvalle/mesos-agent:1.6.2
image: quay.io/aurorascheduler/mesos-agent:1.9.0
pid: host
restart: on-failure
ports:
- "5061:5061"
- "5052:5051"
environment:
MESOS_MASTER: zk://192.168.33.2:2181/mesos
MESOS_CONTAINERIZERS: docker,mesos
MESOS_HOSTNAME: localhost
MESOS_PORT: 5061
MESOS_PORT: 5051
MESOS_HOSTNAME: agent-two
MESOS_RESOURCES: ports(*):[11000-11999]
MESOS_SYSTEMD_ENABLE_SUPPORT: 'false'
MESOS_WORK_DIR: /tmp/mesos
MESOS_ATTRIBUTES: 'host:agent-two;rack:2;zone:west'
networks:
aurora_cluster:
ipv4_address: 192.168.33.5
volumes:
- /sys/fs/cgroup:/sys/fs/cgroup
- /var/run/docker.sock:/var/run/docker.sock
- /sys/fs/cgroup:/sys/fs/cgroup
- /var/run/docker.sock:/var/run/docker.sock
depends_on:
- zk
- zk
agent-three:
image: quay.io/aurorascheduler/mesos-agent:1.9.0
pid: host
restart: on-failure
ports:
- "5053:5051"
environment:
MESOS_MASTER: zk://192.168.33.2:2181/mesos
MESOS_CONTAINERIZERS: docker,mesos
MESOS_PORT: 5051
MESOS_HOSTNAME: agent-three
MESOS_RESOURCES: ports(*):[11000-11999]
MESOS_SYSTEMD_ENABLE_SUPPORT: 'false'
MESOS_WORK_DIR: /tmp/mesos
MESOS_ATTRIBUTES: 'host:agent-three;rack:2;zone:west;dedicated:vagrant/bar'
networks:
aurora_cluster:
ipv4_address: 192.168.33.6
volumes:
- /sys/fs/cgroup:/sys/fs/cgroup
- /var/run/docker.sock:/var/run/docker.sock
depends_on:
- zk
aurora-one:
image: rdelvalle/aurora:0.22.0
image: quay.io/aurorascheduler/scheduler:0.25.0
pid: host
ports:
- "8081:8081"
@ -89,7 +116,14 @@ services:
CLUSTER_NAME: test-cluster
ZK_ENDPOINTS: "192.168.33.2:2181"
MESOS_MASTER: "zk://192.168.33.2:2181/mesos"
EXTRA_SCHEDULER_ARGS: "-min_required_instances_for_sla_check=1"
EXTRA_SCHEDULER_ARGS: >
-http_authentication_mechanism=BASIC
-shiro_realm_modules=INI_AUTHNZ
-shiro_ini_path=/etc/aurora/security.ini
-min_required_instances_for_sla_check=1
-thermos_executor_cpu=0.09
volumes:
- ./.aurora-config:/etc/aurora
networks:
aurora_cluster:
ipv4_address: 192.168.33.7

View file

@ -19,25 +19,18 @@ This also allows us to delete and recreate our development cluster very quickly.
To install docker-compose please follow the instructions for your platform
[here](https://docs.docker.com/compose/install/).
### Getting the source code
As of go 1.10.x, GOPATH is still relevant. This may change in the future but
for the sake of making development less error prone, it is suggested that the following
directories be created:
`$ git clone https://github.com/aurora-scheduler/gorealis`
`$ mkdir -p $GOPATH/src/github.com/paypal`
Inside of the newly cloned repo you may download dependencies to the local cache using go mod
And then clone the master branch into the newly created folder:
`$ cd $GOPATH/src/github.com/paypal; git clone git@github.com:paypal/gorealis.git`
Since we check in our vendor folder, gorealis no further set up is needed.
`$ go mod download`
### Bringing up the cluster
To develop gorealis, you will need a fully functioning Mesos cluster along with
Apache Aurora.
To develop gorealis, you will need a fully functioning Mesos cluster along with
the Aurora Scheduler.
In order to bring up our docker-compose set up execute the following command from the root
of the git repository:
@ -62,14 +55,14 @@ environment but not when running under MacOS. To run code involving the ZK leade
For example, running the tests in a container can be done through the following command from
the root of the git repository:
`$ docker run -t -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10.3-alpine go test github.com/paypal/gorealis`
`$ docker run -t -v $(pwd):/go/src/github.com/aurora-scheduler/gorealis --network gorealis_aurora_cluster golang:1.14.3-alpine go test github.com/paypal/gorealis`
Or
`$ ./runTestsMac.sh`
Alternatively, if an interactive shell is necessary, the following command may be used:
`$ docker run -it -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10.3-alpine /bin/sh`
`$ docker run -it -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.14.3-alpine /bin/sh`
### Cleaning up the cluster
@ -85,6 +78,3 @@ Once development is done, the environment may be torn down by executing (from th
git directory):
`$ docker-compose down`

View file

@ -88,6 +88,12 @@ On Ubuntu, restarting the aurora-scheduler can be achieved by running the follow
$ sudo service aurora-scheduler restart
```
### Using a custom client
Pystachio does not yet support launching tasks using custom executors. Therefore, a custom
client must be used in order to launch tasks using a custom executor. In this case,
we will be using [gorealis](https://github.com/paypal/gorealis) to launch a task with
the compose executor on Aurora.
## Using [dce-go](https://github.com/paypal/dce-go)
Instead of manually configuring Aurora to run the docker-compose executor, one can follow the instructions provided [here](https://github.com/paypal/dce-go/blob/develop/docs/environment.md) to quickly create a DCE environment that would include mesos, aurora, golang1.7, docker, docker-compose and DCE installed.
@ -101,12 +107,80 @@ Mesos endpoint --> http://192.168.33.8:5050
### Installing Go
Follow the instructions at the official golang website: [golang.org/doc/install](https://golang.org/doc/install)
#### Linux
### Installing docker-compose
##### Ubuntu
Agents which will run dce-go will need docker-compose in order to sucessfully run the executor.
Instructions for installing docker-compose on various platforms may be found on Docker's webiste: [docs.docker.com/compose/install/](https://docs.docker.com/compose/install/)
###### Adding a PPA and install via apt-get
```
$ sudo add-apt-repository ppa:ubuntu-lxc/lxd-stable
$ sudo apt-get update
$ sudo apt-get install golang
```
###### Configuring the GOPATH
Configure the environment to be able to compile and run Go code.
```
$ mkdir $HOME/go
$ echo export GOPATH=$HOME/go >> $HOME/.bashrc
$ echo export GOROOT=/usr/lib/go >> $HOME/.bashrc
$ echo export PATH=$PATH:$GOPATH/bin >> $HOME/.bashrc
$ echo export PATH=$PATH:$GOROOT/bin >> $HOME/.bashrc
```
Finally we must reload the .bashrc configuration:
```
$ source $HOME/.bashrc
```
#### OS X
One way to install go on OS X is by using [Homebrew](http://brew.sh/)
##### Installing Homebrew
Run the following command from the terminal to install Hombrew:
```
$ /usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"
```
##### Installing Go using Hombrew
Run the following command from the terminal to install Go:
```
$ brew install go
```
##### Configuring the GOPATH
Configure the environment to be able to compile and run Go code.
```
$ mkdir $HOME/go
$ echo export GOPATH=$HOME/go >> $HOME/.profile
$ echo export GOROOT=/usr/local/opt/go/libexec >> $HOME/.profile
$ echo export PATH=$PATH:$GOPATH/bin >> $HOME/.profile
$ echo export PATH=$PATH:$GOROOT/bin >> $HOME/.profile
```
Finally we must reload the .profile configuration:
```
$ source $HOME/.profile
```
#### Windows
Download and run the msi installer from https://golang.org/dl/
## Installing Docker Compose (if manually configured Aurora)
To show Aurora's new multi executor feature, we need to use at least one custom executor.
In this case we will be using the [docker-compose-executor](https://github.com/mesos/docker-compose-executor).
In order to run the docker-compose executor, each agent must have docker-compose installed on it.
This can be done using pip:
```
$ sudo pip install docker-compose
```
## Downloading gorealis
Finally, we must get `gorealis` using the `go get` command:
@ -118,7 +192,7 @@ go get github.com/paypal/gorealis
# Creating Aurora Jobs
## Creating a thermos job
To demonstrate that we are able to run jobs using different executors on the
To demonstrate that we are able to run jobs using different executors on the
same scheduler, we'll first launch a thermos job using the default Aurora Client.
We can use a sample job for this:
@ -173,6 +247,9 @@ job = realis.NewJob().
RAM(64).
Disk(100).
IsService(false).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1).
AddLabel("fileName", "sample-app/docker-compose.yml").
@ -185,8 +262,8 @@ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=compo
```
If everything went according to plan, a new job will be shown in the Aurora UI.
We can further investigate inside the Mesos task sandbox. Inside the sandbox, under
the sample-app folder, we can find a docker-compose.yml-generated.yml. If we inspect this file,
We can further investigate inside the Mesos task sandbox. Inside the sandbox, under
the sample-app folder, we can find a docker-compose.yml-generated.yml. If we inspect this file,
we can find the port at which we can find the web server we launched.
Under Web->Ports, we find the port Mesos allocated. We can then navigate to:
@ -195,10 +272,10 @@ Under Web->Ports, we find the port Mesos allocated. We can then navigate to:
A message from the executor should greet us.
## Creating a Thermos job using gorealis
It is also possible to create a thermos job using gorealis. To do this, however,
It is also possible to create a thermos job using gorealis. To do this, however,
a thermos payload is required. A thermos payload consists of a JSON blob that details
the entire task as it exists inside the Aurora Scheduler. *Creating the blob is unfortunately
out of the scope of what gorealis does*, so a thermos payload must be generated beforehand or
out of the scope of what gorealis does*, so a thermos payload must be generated beforehand or
retrieved from the structdump of an existing task for testing purposes.
A sample thermos JSON payload may be found [here](../examples/thermos_payload.json) in the examples folder.
@ -217,6 +294,9 @@ job = realis.NewJob().
RAM(64).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1)
```

View file

@ -25,6 +25,9 @@ job = realis.NewJob().
RAM(64).
Disk(100).
IsService(false).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1).
AddLabel("fileName", "sample-app/docker-compose.yml").
@ -57,19 +60,4 @@ updateJob := realis.NewUpdateJob(job)
updateJob.InstanceCount(1)
updateJob.Ram(128)
msg, err := r.UpdateJob(updateJob, "")
```
* Handling a timeout scenario:
When sending an API call to Aurora, the call may timeout at the client side.
This means that the time limit has been reached while waiting for the scheduler
to reply. In such a case it is recommended that the timeout is increased through
the use of the `realis.TimeoutMS()` option.
As these timeouts cannot be totally avoided, there exists a mechanism to mitigate such
scenarios. The `StartJobUpdate` and `CreateService` API will return an error that
implements the Timeout interface.
An error can be checked to see if it is a Timeout error by using the `realis.IsTimeout()`
function.
```

View file

@ -1,6 +1,6 @@
# Using the Sample client
## Usage:
## Usage:
```
Usage of ./client:
-cluster string
@ -22,28 +22,25 @@ Usage of ./client:
```
## Sample commands:
These commands are set to run on a vagrant box. To be able to run the docker compose
executor examples, the vagrant box must be configured properly to use the docker compose executor.
### Thermos
#### Creating a Thermos job
```
$ cd $GOPATH/src/github.com/paypal/gorealis/examples
$ go run client.go -executor=thermos -url=http://192.168.33.7:8081 -cmd=create
$ go run examples/client.go -url=http://localhost:8081 -executor=thermos -cmd=create
```
#### Kill a Thermos job
```
$ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=thermos -url=http://192.168.33.7:8081 -cmd=kill
$ go run examples/client.go -url=http://localhost:8081 -executor=thermos -cmd=kill
```
### Docker Compose executor (custom executor)
#### Creating Docker Compose executor job
```
$ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=compose -url=http://192.168.33.7:8081 -cmd=create
$ go run examples/client.go -url=http://192.168.33.7:8081 -executor=compose -cmd=create
```
#### Kill a Docker Compose executor job
```
$ go run $GOPATH/src/github.com/paypal/gorealis/examples/client.go -executor=compose -url=http://192.168.33.7:8081 -cmd=kill
$ go run examples/client.go -url=http://192.168.33.7:8081 -executor=compose -cmd=kill
```

View file

@ -17,14 +17,12 @@ package realis
// Using a pattern described by Dave Cheney to differentiate errors
// https://dave.cheney.net/2016/04/27/dont-just-check-errors-handle-them-gracefully
// Timeout errors are returned when a function is unable to continue executing due
// Timedout errors are returned when a function is unable to continue executing due
// to a time constraint or meeting a set number of retries.
type timeout interface {
Timedout() bool
}
// IsTimeout returns true if the error being passed as an argument implements the Timeout interface
// and the Timedout function returns true.
func IsTimeout(err error) bool {
temp, ok := err.(timeout)
return ok && temp.Timedout()
@ -63,42 +61,41 @@ func (r *retryErr) RetryCount() int {
return r.retryCount
}
// ToRetryCount is a helper function for testing verification to avoid whitebox testing
// Helper function for testing verification to avoid whitebox testing
// as well as keeping retryErr as a private.
// Should NOT be used under any other context.
func ToRetryCount(err error) *retryErr {
if retryErr, ok := err.(*retryErr); ok {
return retryErr
} else {
return nil
}
return nil
}
func newRetryError(err error, retryCount int) *retryErr {
return &retryErr{error: err, timedout: true, retryCount: retryCount}
}
// Temporary errors indicate that the action may or should be retried.
// Temporary errors indicate that the action may and should be retried.
type temporary interface {
Temporary() bool
}
// IsTemporary indicates whether the error passed in as an argument implements the temporary interface
// and if the Temporary function returns true.
func IsTemporary(err error) bool {
temp, ok := err.(temporary)
return ok && temp.Temporary()
}
type temporaryErr struct {
type TemporaryErr struct {
error
temporary bool
}
func (t *temporaryErr) Temporary() bool {
func (t *TemporaryErr) Temporary() bool {
return t.temporary
}
// NewTemporaryError creates a new error which satisfies the Temporary interface.
func NewTemporaryError(err error) *temporaryErr {
return &temporaryErr{error: err, temporary: true}
// Retrying after receiving this error is advised
func NewTemporaryError(err error) *TemporaryErr {
return &TemporaryErr{error: err, temporary: true}
}

View file

@ -17,24 +17,22 @@ package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
"strings"
"time"
realis "github.com/paypal/gorealis"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/paypal/gorealis/response"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
var cmd, executor, url, clustersConfig, clusterName, updateId, username, password, zkUrl, hostList, role string
var caCertsPath string
var clientKey, clientCert string
var ConnectionTimeout = 20000
var ConnectionTimeout = 20 * time.Second
func init() {
flag.StringVar(&cmd, "cmd", "", "Job request type to send to Aurora Scheduler")
flag.StringVar(&cmd, "cmd", "", "Aurora Job request type to send to Aurora Scheduler")
flag.StringVar(&executor, "executor", "thermos", "Executor to use")
flag.StringVar(&url, "url", "", "URL at which the Aurora Scheduler exists as [url]:[port]")
flag.StringVar(&clustersConfig, "clusters", "", "Location of the clusters.json file used by aurora.")
@ -74,15 +72,14 @@ func init() {
func main() {
var job realis.Job
var job *realis.AuroraJob
var err error
var monitor *realis.Monitor
var r realis.Realis
var r *realis.Client
clientOptions := []realis.ClientOption{
realis.BasicAuth(username, password),
realis.ThriftJSON(),
realis.TimeoutMS(ConnectionTimeout),
realis.Timeout(ConnectionTimeout),
realis.BackOff(realis.Backoff{
Steps: 2,
Duration: 10 * time.Second,
@ -100,39 +97,39 @@ func main() {
}
if caCertsPath != "" {
clientOptions = append(clientOptions, realis.Certspath(caCertsPath))
clientOptions = append(clientOptions, realis.CertsPath(caCertsPath))
}
if clientKey != "" && clientCert != "" {
clientOptions = append(clientOptions, realis.ClientCerts(clientKey, clientCert))
}
r, err = realis.NewRealisClient(clientOptions...)
r, err = realis.NewClient(clientOptions...)
if err != nil {
log.Fatalln(err)
}
monitor = &realis.Monitor{r}
defer r.Close()
switch executor {
case "thermos":
payload, err := ioutil.ReadFile("examples/thermos_payload.json")
if err != nil {
log.Fatalln("Error reading json config file: ", err)
}
thermosExec := realis.ThermosExecutor{}
thermosExec.AddProcess(realis.NewThermosProcess("boostrap", "echo bootsrapping")).
AddProcess(realis.NewThermosProcess("hello_gorealis", "while true; do echo hello world from gorealis; sleep 10; done"))
job = realis.NewJob().
Environment("prod").
Role("vagrant").
Name("hello_world_from_gorealis").
ExecutorName(aurora.AURORA_EXECUTOR_NAME).
ExecutorData(string(payload)).
CPU(1).
RAM(64).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1)
AddPorts(1).
ThermosExecutor(thermosExec)
case "compose":
job = realis.NewJob().
Environment("prod").
@ -144,6 +141,9 @@ func main() {
RAM(512).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(4).
AddLabel("fileName", "sample-app/docker-compose.yml").
@ -157,6 +157,9 @@ func main() {
RAM(64).
Disk(100).
IsService(true).
Production(false).
Tier("preemptible").
Priority(0).
InstanceCount(1).
AddPorts(1)
default:
@ -166,14 +169,13 @@ func main() {
switch cmd {
case "create":
fmt.Println("Creating job")
resp, err := r.CreateJob(job)
err := r.CreateJob(job)
if err != nil {
log.Fatalln(err)
}
fmt.Println(resp.String())
if ok, mErr := monitor.Instances(job.JobKey(), job.GetInstanceCount(), 5, 50); !ok || mErr != nil {
_, err := r.KillJob(job.JobKey())
if ok, mErr := r.MonitorInstances(job.JobKey(), job.GetInstanceCount(), 5*time.Second, 50*time.Second); !ok || mErr != nil {
err := r.KillJob(job.JobKey())
if err != nil {
log.Fatalln(err)
}
@ -183,18 +185,17 @@ func main() {
case "createService":
// Create a service with three instances using the update API instead of the createJob API
fmt.Println("Creating service")
settings := realis.NewUpdateSettings()
job.InstanceCount(3)
resp, result, err := r.CreateService(job, settings)
settings := realis.JobUpdateFromAuroraTask(job.AuroraTask()).InstanceCount(3)
result, err := r.CreateService(settings)
if err != nil {
log.Println("error: ", err)
log.Fatal("response: ", resp.String())
log.Fatal("error: ", err)
}
fmt.Println(result.String())
if ok, mErr := monitor.JobUpdate(*result.GetKey(), 5, 180); !ok || mErr != nil {
_, err := r.AbortJobUpdate(*result.GetKey(), "Monitor timed out")
_, err = r.KillJob(job.JobKey())
if ok, mErr := r.MonitorJobUpdate(*result.GetKey(), 5*time.Second, 180*time.Second); !ok || mErr != nil {
err := r.AbortJobUpdate(*result.GetKey(), "Monitor timed out")
err = r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
@ -205,14 +206,13 @@ func main() {
fmt.Println("Creating a docker based job")
container := realis.NewDockerContainer().Image("python:2.7").AddParameter("network", "host")
job.Container(container)
resp, err := r.CreateJob(job)
err := r.CreateJob(job)
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
if ok, err := monitor.Instances(job.JobKey(), job.GetInstanceCount(), 10, 300); !ok || err != nil {
_, err := r.KillJob(job.JobKey())
if ok, err := r.MonitorInstances(job.JobKey(), job.GetInstanceCount(), 10*time.Second, 300*time.Second); !ok || err != nil {
err := r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
@ -222,14 +222,13 @@ func main() {
fmt.Println("Creating a docker based job")
container := realis.NewMesosContainer().DockerImage("python", "2.7")
job.Container(container)
resp, err := r.CreateJob(job)
err := r.CreateJob(job)
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
if ok, err := monitor.Instances(job.JobKey(), job.GetInstanceCount(), 10, 300); !ok || err != nil {
_, err := r.KillJob(job.JobKey())
if ok, err := r.MonitorInstances(job.JobKey(), job.GetInstanceCount(), 10*time.Second, 300*time.Second); !ok || err != nil {
err := r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
@ -240,50 +239,44 @@ func main() {
// Cron config
job.CronSchedule("* * * * *")
job.IsService(false)
resp, err := r.ScheduleCronJob(job)
err := r.ScheduleCronJob(job)
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "startCron":
fmt.Println("Starting a Cron job")
resp, err := r.StartCronJob(job.JobKey())
err := r.StartCronJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "descheduleCron":
fmt.Println("Descheduling a Cron job")
resp, err := r.DescheduleCronJob(job.JobKey())
err := r.DescheduleCronJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "kill":
fmt.Println("Killing job")
resp, err := r.KillJob(job.JobKey())
err := r.KillJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
if ok, err := monitor.Instances(job.JobKey(), 0, 5, 50); !ok || err != nil {
if ok, err := r.MonitorInstances(job.JobKey(), 0, 5*time.Second, 50*time.Second); !ok || err != nil {
log.Fatal("Unable to kill all instances of job")
}
fmt.Println(resp.String())
case "restart":
fmt.Println("Restarting job")
resp, err := r.RestartJob(job.JobKey())
err := r.RestartJob(job.JobKey())
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "liveCount":
fmt.Println("Getting instance count")
@ -302,106 +295,110 @@ func main() {
log.Fatal(err)
}
fmt.Println("Number of live instances: ", len(live))
fmt.Println("Active instances: ", live)
case "flexUp":
fmt.Println("Flexing up job")
numOfInstances := int32(4)
numOfInstances := 4
live, err := r.GetInstanceIds(job.JobKey(), aurora.ACTIVE_STATES)
if err != nil {
log.Fatal(err)
}
currInstances := int32(len(live))
currInstances := len(live)
fmt.Println("Current num of instances: ", currInstances)
resp, err := r.AddInstances(aurora.InstanceKey{
JobKey: job.JobKey(),
key := job.JobKey()
err = r.AddInstances(aurora.InstanceKey{
JobKey: &key,
InstanceId: live[0],
},
numOfInstances)
int32(numOfInstances))
if err != nil {
log.Fatal(err)
}
if ok, err := monitor.Instances(job.JobKey(), currInstances+numOfInstances, 5, 50); !ok || err != nil {
if ok, err := r.MonitorInstances(job.JobKey(), int32(currInstances+numOfInstances), 5*time.Second, 50*time.Second); !ok || err != nil {
fmt.Println("Flexing up failed")
}
fmt.Println(resp.String())
case "flexDown":
fmt.Println("Flexing down job")
numOfInstances := int32(2)
numOfInstances := 2
live, err := r.GetInstanceIds(job.JobKey(), aurora.ACTIVE_STATES)
if err != nil {
log.Fatal(err)
}
currInstances := int32(len(live))
currInstances := len(live)
fmt.Println("Current num of instances: ", currInstances)
resp, err := r.RemoveInstances(job.JobKey(), numOfInstances)
err = r.RemoveInstances(job.JobKey(), numOfInstances)
if err != nil {
log.Fatal(err)
}
if ok, err := monitor.Instances(job.JobKey(), currInstances-numOfInstances, 5, 100); !ok || err != nil {
if ok, err := r.MonitorInstances(job.JobKey(), int32(currInstances-numOfInstances), 5*time.Second, 100*time.Second); !ok || err != nil {
fmt.Println("flexDown failed")
}
fmt.Println(resp.String())
case "update":
fmt.Println("Updating a job with with more RAM and to 5 instances")
live, err := r.GetInstanceIds(job.JobKey(), aurora.ACTIVE_STATES)
if err != nil {
log.Fatal(err)
}
key := job.JobKey()
taskConfig, err := r.FetchTaskConfig(aurora.InstanceKey{
JobKey: job.JobKey(),
JobKey: &key,
InstanceId: live[0],
})
if err != nil {
log.Fatal(err)
}
updateJob := realis.NewDefaultUpdateJob(taskConfig)
updateJob.InstanceCount(5).RAM(128)
updateJob := realis.JobUpdateFromConfig(taskConfig).InstanceCount(5).RAM(128)
resp, err := r.StartJobUpdate(updateJob, "")
result, err := r.StartJobUpdate(updateJob, "")
if err != nil {
log.Fatal(err)
}
jobUpdateKey := response.JobUpdateKey(resp)
monitor.JobUpdate(*jobUpdateKey, 5, 500)
jobUpdateKey := result.GetKey()
_, err = r.MonitorJobUpdate(*jobUpdateKey, 5*time.Second, 6*time.Minute)
if err != nil {
log.Fatal(err)
}
case "pauseJobUpdate":
resp, err := r.PauseJobUpdate(&aurora.JobUpdateKey{
Job: job.JobKey(),
key := job.JobKey()
err := r.PauseJobUpdate(&aurora.JobUpdateKey{
Job: &key,
ID: updateId,
}, "")
if err != nil {
log.Fatal(err)
}
fmt.Println("PauseJobUpdate response: ", resp.String())
case "resumeJobUpdate":
resp, err := r.ResumeJobUpdate(&aurora.JobUpdateKey{
Job: job.JobKey(),
key := job.JobKey()
err := r.ResumeJobUpdate(aurora.JobUpdateKey{
Job: &key,
ID: updateId,
}, "")
if err != nil {
log.Fatal(err)
}
fmt.Println("ResumeJobUpdate response: ", resp.String())
case "pulseJobUpdate":
resp, err := r.PulseJobUpdate(&aurora.JobUpdateKey{
Job: job.JobKey(),
key := job.JobKey()
resp, err := r.PulseJobUpdate(aurora.JobUpdateKey{
Job: &key,
ID: updateId,
})
if err != nil {
@ -411,9 +408,10 @@ func main() {
fmt.Println("PulseJobUpdate response: ", resp.String())
case "updateDetails":
resp, err := r.JobUpdateDetails(aurora.JobUpdateQuery{
key := job.JobKey()
result, err := r.JobUpdateDetails(aurora.JobUpdateQuery{
Key: &aurora.JobUpdateKey{
Job: job.JobKey(),
Job: &key,
ID: updateId,
},
Limit: 1,
@ -423,12 +421,13 @@ func main() {
log.Fatal(err)
}
fmt.Println(response.JobUpdateDetails(resp))
fmt.Println(result)
case "abortUpdate":
fmt.Println("Abort update")
resp, err := r.AbortJobUpdate(aurora.JobUpdateKey{
Job: job.JobKey(),
key := job.JobKey()
err := r.AbortJobUpdate(aurora.JobUpdateKey{
Job: &key,
ID: updateId,
},
"")
@ -436,12 +435,12 @@ func main() {
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "rollbackUpdate":
fmt.Println("Abort update")
resp, err := r.RollbackJobUpdate(aurora.JobUpdateKey{
Job: job.JobKey(),
key := job.JobKey()
err := r.RollbackJobUpdate(aurora.JobUpdateKey{
Job: &key,
ID: updateId,
},
"")
@ -449,14 +448,6 @@ func main() {
if err != nil {
log.Fatal(err)
}
fmt.Println(resp.String())
case "variableBatchStep":
step, err := r.VariableBatchStep(aurora.JobUpdateKey{Job: job.JobKey(), ID: updateId})
if err != nil {
log.Fatal(err)
}
fmt.Println(step)
case "taskConfig":
fmt.Println("Getting job info")
@ -465,8 +456,9 @@ func main() {
log.Fatal(err)
}
key := job.JobKey()
config, err := r.FetchTaskConfig(aurora.InstanceKey{
JobKey: job.JobKey(),
JobKey: &key,
InstanceId: live[0],
})
@ -478,9 +470,10 @@ func main() {
case "updatesummary":
fmt.Println("Getting job update summary")
key := job.JobKey()
jobquery := &aurora.JobUpdateQuery{
Role: &job.JobKey().Role,
JobKey: job.JobKey(),
Role: &key.Role,
JobKey: &key,
}
updatesummary, err := r.GetJobUpdateSummaries(jobquery)
if err != nil {
@ -491,10 +484,11 @@ func main() {
case "taskStatus":
fmt.Println("Getting task status")
key := job.JobKey()
taskQ := &aurora.TaskQuery{
Role: &job.JobKey().Role,
Environment: &job.JobKey().Environment,
JobName: &job.JobKey().Name,
Role: &key.Role,
Environment: &key.Environment,
JobName: &key.Name,
}
tasks, err := r.GetTaskStatus(taskQ)
if err != nil {
@ -506,10 +500,11 @@ func main() {
case "tasksWithoutConfig":
fmt.Println("Getting task status")
key := job.JobKey()
taskQ := &aurora.TaskQuery{
Role: &job.JobKey().Role,
Environment: &job.JobKey().Environment,
JobName: &job.JobKey().Name,
Role: &key.Role,
Environment: &key.Environment,
JobName: &key.Name,
}
tasks, err := r.GetTasksWithoutConfigs(taskQ)
if err != nil {
@ -525,17 +520,17 @@ func main() {
log.Fatal("No hosts specified to drain")
}
hosts := strings.Split(hostList, ",")
_, result, err := r.DrainHosts(hosts...)
_, err := r.DrainHosts(hosts...)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}
// Monitor change to DRAINING and DRAINED mode
hostResult, err := monitor.HostMaintenance(
hostResult, err := r.MonitorHostMaintenance(
hosts,
[]aurora.MaintenanceMode{aurora.MaintenanceMode_DRAINED, aurora.MaintenanceMode_DRAINING},
5,
10)
5*time.Second,
10*time.Second)
if err != nil {
for host, ok := range hostResult {
if !ok {
@ -545,8 +540,6 @@ func main() {
log.Fatalf("error: %+v\n", err.Error())
}
fmt.Print(result.String())
case "SLADrainHosts":
fmt.Println("Setting hosts to DRAINING using SLA aware draining")
if hostList == "" {
@ -556,17 +549,17 @@ func main() {
policy := aurora.SlaPolicy{PercentageSlaPolicy: &aurora.PercentageSlaPolicy{Percentage: 50.0}}
result, err := r.SLADrainHosts(&policy, 30, hosts...)
_, err := r.SLADrainHosts(&policy, 30, hosts...)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}
// Monitor change to DRAINING and DRAINED mode
hostResult, err := monitor.HostMaintenance(
hostResult, err := r.MonitorHostMaintenance(
hosts,
[]aurora.MaintenanceMode{aurora.MaintenanceMode_DRAINED, aurora.MaintenanceMode_DRAINING},
5,
10)
5*time.Second,
10*time.Second)
if err != nil {
for host, ok := range hostResult {
if !ok {
@ -576,25 +569,23 @@ func main() {
log.Fatalf("error: %+v\n", err.Error())
}
fmt.Print(result.String())
case "endMaintenance":
fmt.Println("Setting hosts to ACTIVE")
if hostList == "" {
log.Fatal("No hosts specified to drain")
}
hosts := strings.Split(hostList, ",")
_, result, err := r.EndMaintenance(hosts...)
_, err := r.EndMaintenance(hosts...)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}
// Monitor change to DRAINING and DRAINED mode
hostResult, err := monitor.HostMaintenance(
hostResult, err := r.MonitorHostMaintenance(
hosts,
[]aurora.MaintenanceMode{aurora.MaintenanceMode_NONE},
5,
10)
5*time.Second,
10*time.Second)
if err != nil {
for host, ok := range hostResult {
if !ok {
@ -604,14 +595,13 @@ func main() {
log.Fatalf("error: %+v\n", err.Error())
}
fmt.Print(result.String())
case "getPendingReasons":
fmt.Println("Getting pending reasons")
key := job.JobKey()
taskQ := &aurora.TaskQuery{
Role: &job.JobKey().Role,
Environment: &job.JobKey().Environment,
JobName: &job.JobKey().Name,
Role: &key.Role,
Environment: &key.Environment,
JobName: &key.Name,
}
reasons, err := r.GetPendingReason(taskQ)
if err != nil {
@ -623,7 +613,7 @@ func main() {
case "getJobs":
fmt.Println("GetJobs...role: ", role)
_, result, err := r.GetJobs(role)
result, err := r.GetJobs(role)
if err != nil {
log.Fatalf("error: %+v\n", err.Error())
}

View file

@ -2,6 +2,7 @@
"name": "devcluster",
"zk": "192.168.33.7",
"scheduler_zk_path": "/aurora/scheduler",
"mesos_zk_path": "/mesos",
"auth_mechanism": "UNAUTHENTICATED",
"slave_run_directory": "latest",
"slave_root": "/var/lib/mesos"

View file

@ -23,8 +23,8 @@ import (
"os"
"time"
realis "github.com/paypal/gorealis"
"github.com/paypal/gorealis/gen-go/apache/aurora"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/pkg/errors"
)
@ -125,7 +125,7 @@ func init() {
}
}
func CreateRealisClient(config *Config) (realis.Realis, error) {
func CreateRealisClient(config *Config) (*realis.Client, error) {
var transportOption realis.ClientOption
// Configuring transport protocol. If not transport is provided, then using JSON as the
// default transport protocol.
@ -157,7 +157,7 @@ func CreateRealisClient(config *Config) (realis.Realis, error) {
clientOptions = append(clientOptions, realis.Debug())
}
return realis.NewRealisClient(clientOptions...)
return realis.NewClient(clientOptions...)
}
func main() {
@ -165,7 +165,6 @@ func main() {
fmt.Println(clientCreationErr)
os.Exit(1)
} else {
monitor := &realis.Monitor{Client: r}
defer r.Close()
uris := job.URIs
labels := job.Labels
@ -178,6 +177,8 @@ func main() {
RAM(job.RAM).
Disk(job.Disk).
IsService(job.Service).
Tier("preemptible").
Priority(0).
InstanceCount(job.Instances).
AddPorts(job.Ports)
@ -205,20 +206,18 @@ func main() {
}
fmt.Println("Creating Job...")
if resp, jobCreationErr := r.CreateJob(auroraJob); jobCreationErr != nil {
if jobCreationErr := r.CreateJob(auroraJob); jobCreationErr != nil {
fmt.Println("Error creating Aurora job: ", jobCreationErr)
os.Exit(1)
} else {
if resp.ResponseCode == aurora.ResponseCode_OK {
if ok, monitorErr := monitor.Instances(auroraJob.JobKey(), auroraJob.GetInstanceCount(), 5, 50); !ok || monitorErr != nil {
if _, jobErr := r.KillJob(auroraJob.JobKey()); jobErr !=
nil {
fmt.Println(jobErr)
os.Exit(1)
} else {
fmt.Println("ok: ", ok)
fmt.Println("jobErr: ", jobErr)
}
if ok, monitorErr := r.MonitorInstances(auroraJob.JobKey(), auroraJob.GetInstanceCount(), 5, 50); !ok || monitorErr != nil {
if jobErr := r.KillJob(auroraJob.JobKey()); jobErr !=
nil {
fmt.Println(jobErr)
os.Exit(1)
} else {
fmt.Println("ok: ", ok)
fmt.Println("jobErr: ", jobErr)
}
}
}

View file

@ -1,62 +0,0 @@
{
"environment": "prod",
"health_check_config": {
"initial_interval_secs": 15.0,
"health_checker": {
"http": {
"expected_response_code": 0,
"endpoint": "/health",
"expected_response": "ok"
}
},
"interval_secs": 10.0,
"timeout_secs": 1.0,
"max_consecutive_failures": 0
},
"name": "hello_world_from_gorealis",
"service": false,
"max_task_failures": 1,
"cron_collision_policy": "KILL_EXISTING",
"enable_hooks": false,
"cluster": "devcluster",
"task": {
"processes": [
{
"daemon": false,
"name": "hello",
"ephemeral": false,
"max_failures": 1,
"min_duration": 5,
"cmdline": "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
"final": false
}
],
"name": "hello",
"finalization_wait": 30,
"max_failures": 1,
"max_concurrency": 0,
"resources": {
"gpu": 0,
"disk": 134217728,
"ram": 134217728,
"cpu": 1.0
},
"constraints": [
{
"order": [
"hello"
]
}
]
},
"production": false,
"role": "vagrant",
"lifecycle": {
"http": {
"graceful_shutdown_endpoint": "/quitquitquit",
"port": "health",
"shutdown_endpoint": "/abortabortabort"
}
},
"priority": 0
}

View file

@ -0,0 +1,28 @@
{
"task": {
"processes": [
{
"daemon": false,
"name": "hello",
"ephemeral": false,
"max_failures": 1,
"min_duration": 5,
"cmdline": "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
"final": false
}
],
"resources": {
"gpu": 0,
"disk": 134217728,
"ram": 134217728,
"cpu": 1.1
},
"constraints": [
{
"order": [
"hello"
]
}
]
}
}

View file

@ -1,5 +1,4 @@
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
package aurora

View file

@ -1,13 +1,12 @@
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
package aurora
import (
import(
"bytes"
"context"
"reflect"
"fmt"
"time"
"github.com/apache/thrift/lib/go/thrift"
)
@ -15,7 +14,7 @@ import (
var _ = thrift.ZERO
var _ = fmt.Printf
var _ = context.Background
var _ = reflect.DeepEqual
var _ = time.Now
var _ = bytes.Equal
const AURORA_EXECUTOR_NAME = "AuroraExecutor"

File diff suppressed because it is too large Load diff

View file

@ -1,22 +1,22 @@
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
package main
import (
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
)
var _ = aurora.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
@ -175,19 +175,19 @@ func main() {
fmt.Fprintln(os.Stderr, "CreateJob requires 1 args")
flag.Usage()
}
arg163 := flag.Arg(1)
mbTrans164 := thrift.NewTMemoryBufferLen(len(arg163))
defer mbTrans164.Close()
_, err165 := mbTrans164.WriteString(arg163)
if err165 != nil {
arg213 := flag.Arg(1)
mbTrans214 := thrift.NewTMemoryBufferLen(len(arg213))
defer mbTrans214.Close()
_, err215 := mbTrans214.WriteString(arg213)
if err215 != nil {
Usage()
return
}
factory166 := thrift.NewTJSONProtocolFactory()
jsProt167 := factory166.GetProtocol(mbTrans164)
factory216 := thrift.NewTJSONProtocolFactory()
jsProt217 := factory216.GetProtocol(mbTrans214)
argvalue0 := aurora.NewJobConfiguration()
err168 := argvalue0.Read(jsProt167)
if err168 != nil {
err218 := argvalue0.Read(context.Background(), jsProt217)
if err218 != nil {
Usage()
return
}
@ -200,19 +200,19 @@ func main() {
fmt.Fprintln(os.Stderr, "ScheduleCronJob requires 1 args")
flag.Usage()
}
arg169 := flag.Arg(1)
mbTrans170 := thrift.NewTMemoryBufferLen(len(arg169))
defer mbTrans170.Close()
_, err171 := mbTrans170.WriteString(arg169)
if err171 != nil {
arg219 := flag.Arg(1)
mbTrans220 := thrift.NewTMemoryBufferLen(len(arg219))
defer mbTrans220.Close()
_, err221 := mbTrans220.WriteString(arg219)
if err221 != nil {
Usage()
return
}
factory172 := thrift.NewTJSONProtocolFactory()
jsProt173 := factory172.GetProtocol(mbTrans170)
factory222 := thrift.NewTJSONProtocolFactory()
jsProt223 := factory222.GetProtocol(mbTrans220)
argvalue0 := aurora.NewJobConfiguration()
err174 := argvalue0.Read(jsProt173)
if err174 != nil {
err224 := argvalue0.Read(context.Background(), jsProt223)
if err224 != nil {
Usage()
return
}
@ -225,19 +225,19 @@ func main() {
fmt.Fprintln(os.Stderr, "DescheduleCronJob requires 1 args")
flag.Usage()
}
arg175 := flag.Arg(1)
mbTrans176 := thrift.NewTMemoryBufferLen(len(arg175))
defer mbTrans176.Close()
_, err177 := mbTrans176.WriteString(arg175)
if err177 != nil {
arg225 := flag.Arg(1)
mbTrans226 := thrift.NewTMemoryBufferLen(len(arg225))
defer mbTrans226.Close()
_, err227 := mbTrans226.WriteString(arg225)
if err227 != nil {
Usage()
return
}
factory178 := thrift.NewTJSONProtocolFactory()
jsProt179 := factory178.GetProtocol(mbTrans176)
factory228 := thrift.NewTJSONProtocolFactory()
jsProt229 := factory228.GetProtocol(mbTrans226)
argvalue0 := aurora.NewJobKey()
err180 := argvalue0.Read(jsProt179)
if err180 != nil {
err230 := argvalue0.Read(context.Background(), jsProt229)
if err230 != nil {
Usage()
return
}
@ -250,19 +250,19 @@ func main() {
fmt.Fprintln(os.Stderr, "StartCronJob requires 1 args")
flag.Usage()
}
arg181 := flag.Arg(1)
mbTrans182 := thrift.NewTMemoryBufferLen(len(arg181))
defer mbTrans182.Close()
_, err183 := mbTrans182.WriteString(arg181)
if err183 != nil {
arg231 := flag.Arg(1)
mbTrans232 := thrift.NewTMemoryBufferLen(len(arg231))
defer mbTrans232.Close()
_, err233 := mbTrans232.WriteString(arg231)
if err233 != nil {
Usage()
return
}
factory184 := thrift.NewTJSONProtocolFactory()
jsProt185 := factory184.GetProtocol(mbTrans182)
factory234 := thrift.NewTJSONProtocolFactory()
jsProt235 := factory234.GetProtocol(mbTrans232)
argvalue0 := aurora.NewJobKey()
err186 := argvalue0.Read(jsProt185)
if err186 != nil {
err236 := argvalue0.Read(context.Background(), jsProt235)
if err236 != nil {
Usage()
return
}
@ -275,36 +275,36 @@ func main() {
fmt.Fprintln(os.Stderr, "RestartShards requires 2 args")
flag.Usage()
}
arg187 := flag.Arg(1)
mbTrans188 := thrift.NewTMemoryBufferLen(len(arg187))
defer mbTrans188.Close()
_, err189 := mbTrans188.WriteString(arg187)
if err189 != nil {
arg237 := flag.Arg(1)
mbTrans238 := thrift.NewTMemoryBufferLen(len(arg237))
defer mbTrans238.Close()
_, err239 := mbTrans238.WriteString(arg237)
if err239 != nil {
Usage()
return
}
factory190 := thrift.NewTJSONProtocolFactory()
jsProt191 := factory190.GetProtocol(mbTrans188)
factory240 := thrift.NewTJSONProtocolFactory()
jsProt241 := factory240.GetProtocol(mbTrans238)
argvalue0 := aurora.NewJobKey()
err192 := argvalue0.Read(jsProt191)
if err192 != nil {
err242 := argvalue0.Read(context.Background(), jsProt241)
if err242 != nil {
Usage()
return
}
value0 := argvalue0
arg193 := flag.Arg(2)
mbTrans194 := thrift.NewTMemoryBufferLen(len(arg193))
defer mbTrans194.Close()
_, err195 := mbTrans194.WriteString(arg193)
if err195 != nil {
arg243 := flag.Arg(2)
mbTrans244 := thrift.NewTMemoryBufferLen(len(arg243))
defer mbTrans244.Close()
_, err245 := mbTrans244.WriteString(arg243)
if err245 != nil {
Usage()
return
}
factory196 := thrift.NewTJSONProtocolFactory()
jsProt197 := factory196.GetProtocol(mbTrans194)
factory246 := thrift.NewTJSONProtocolFactory()
jsProt247 := factory246.GetProtocol(mbTrans244)
containerStruct1 := aurora.NewAuroraSchedulerManagerRestartShardsArgs()
err198 := containerStruct1.ReadField2(jsProt197)
if err198 != nil {
err248 := containerStruct1.ReadField2(context.Background(), jsProt247)
if err248 != nil {
Usage()
return
}
@ -318,36 +318,36 @@ func main() {
fmt.Fprintln(os.Stderr, "KillTasks requires 3 args")
flag.Usage()
}
arg199 := flag.Arg(1)
mbTrans200 := thrift.NewTMemoryBufferLen(len(arg199))
defer mbTrans200.Close()
_, err201 := mbTrans200.WriteString(arg199)
if err201 != nil {
arg249 := flag.Arg(1)
mbTrans250 := thrift.NewTMemoryBufferLen(len(arg249))
defer mbTrans250.Close()
_, err251 := mbTrans250.WriteString(arg249)
if err251 != nil {
Usage()
return
}
factory202 := thrift.NewTJSONProtocolFactory()
jsProt203 := factory202.GetProtocol(mbTrans200)
factory252 := thrift.NewTJSONProtocolFactory()
jsProt253 := factory252.GetProtocol(mbTrans250)
argvalue0 := aurora.NewJobKey()
err204 := argvalue0.Read(jsProt203)
if err204 != nil {
err254 := argvalue0.Read(context.Background(), jsProt253)
if err254 != nil {
Usage()
return
}
value0 := argvalue0
arg205 := flag.Arg(2)
mbTrans206 := thrift.NewTMemoryBufferLen(len(arg205))
defer mbTrans206.Close()
_, err207 := mbTrans206.WriteString(arg205)
if err207 != nil {
arg255 := flag.Arg(2)
mbTrans256 := thrift.NewTMemoryBufferLen(len(arg255))
defer mbTrans256.Close()
_, err257 := mbTrans256.WriteString(arg255)
if err257 != nil {
Usage()
return
}
factory208 := thrift.NewTJSONProtocolFactory()
jsProt209 := factory208.GetProtocol(mbTrans206)
factory258 := thrift.NewTJSONProtocolFactory()
jsProt259 := factory258.GetProtocol(mbTrans256)
containerStruct1 := aurora.NewAuroraSchedulerManagerKillTasksArgs()
err210 := containerStruct1.ReadField2(jsProt209)
if err210 != nil {
err260 := containerStruct1.ReadField2(context.Background(), jsProt259)
if err260 != nil {
Usage()
return
}
@ -363,25 +363,25 @@ func main() {
fmt.Fprintln(os.Stderr, "AddInstances requires 2 args")
flag.Usage()
}
arg212 := flag.Arg(1)
mbTrans213 := thrift.NewTMemoryBufferLen(len(arg212))
defer mbTrans213.Close()
_, err214 := mbTrans213.WriteString(arg212)
if err214 != nil {
arg262 := flag.Arg(1)
mbTrans263 := thrift.NewTMemoryBufferLen(len(arg262))
defer mbTrans263.Close()
_, err264 := mbTrans263.WriteString(arg262)
if err264 != nil {
Usage()
return
}
factory215 := thrift.NewTJSONProtocolFactory()
jsProt216 := factory215.GetProtocol(mbTrans213)
factory265 := thrift.NewTJSONProtocolFactory()
jsProt266 := factory265.GetProtocol(mbTrans263)
argvalue0 := aurora.NewInstanceKey()
err217 := argvalue0.Read(jsProt216)
if err217 != nil {
err267 := argvalue0.Read(context.Background(), jsProt266)
if err267 != nil {
Usage()
return
}
value0 := argvalue0
tmp1, err218 := (strconv.Atoi(flag.Arg(2)))
if err218 != nil {
tmp1, err268 := (strconv.Atoi(flag.Arg(2)))
if err268 != nil {
Usage()
return
}
@ -395,19 +395,19 @@ func main() {
fmt.Fprintln(os.Stderr, "ReplaceCronTemplate requires 1 args")
flag.Usage()
}
arg219 := flag.Arg(1)
mbTrans220 := thrift.NewTMemoryBufferLen(len(arg219))
defer mbTrans220.Close()
_, err221 := mbTrans220.WriteString(arg219)
if err221 != nil {
arg269 := flag.Arg(1)
mbTrans270 := thrift.NewTMemoryBufferLen(len(arg269))
defer mbTrans270.Close()
_, err271 := mbTrans270.WriteString(arg269)
if err271 != nil {
Usage()
return
}
factory222 := thrift.NewTJSONProtocolFactory()
jsProt223 := factory222.GetProtocol(mbTrans220)
factory272 := thrift.NewTJSONProtocolFactory()
jsProt273 := factory272.GetProtocol(mbTrans270)
argvalue0 := aurora.NewJobConfiguration()
err224 := argvalue0.Read(jsProt223)
if err224 != nil {
err274 := argvalue0.Read(context.Background(), jsProt273)
if err274 != nil {
Usage()
return
}
@ -420,19 +420,19 @@ func main() {
fmt.Fprintln(os.Stderr, "StartJobUpdate requires 2 args")
flag.Usage()
}
arg225 := flag.Arg(1)
mbTrans226 := thrift.NewTMemoryBufferLen(len(arg225))
defer mbTrans226.Close()
_, err227 := mbTrans226.WriteString(arg225)
if err227 != nil {
arg275 := flag.Arg(1)
mbTrans276 := thrift.NewTMemoryBufferLen(len(arg275))
defer mbTrans276.Close()
_, err277 := mbTrans276.WriteString(arg275)
if err277 != nil {
Usage()
return
}
factory228 := thrift.NewTJSONProtocolFactory()
jsProt229 := factory228.GetProtocol(mbTrans226)
factory278 := thrift.NewTJSONProtocolFactory()
jsProt279 := factory278.GetProtocol(mbTrans276)
argvalue0 := aurora.NewJobUpdateRequest()
err230 := argvalue0.Read(jsProt229)
if err230 != nil {
err280 := argvalue0.Read(context.Background(), jsProt279)
if err280 != nil {
Usage()
return
}
@ -447,19 +447,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PauseJobUpdate requires 2 args")
flag.Usage()
}
arg232 := flag.Arg(1)
mbTrans233 := thrift.NewTMemoryBufferLen(len(arg232))
defer mbTrans233.Close()
_, err234 := mbTrans233.WriteString(arg232)
if err234 != nil {
arg282 := flag.Arg(1)
mbTrans283 := thrift.NewTMemoryBufferLen(len(arg282))
defer mbTrans283.Close()
_, err284 := mbTrans283.WriteString(arg282)
if err284 != nil {
Usage()
return
}
factory235 := thrift.NewTJSONProtocolFactory()
jsProt236 := factory235.GetProtocol(mbTrans233)
factory285 := thrift.NewTJSONProtocolFactory()
jsProt286 := factory285.GetProtocol(mbTrans283)
argvalue0 := aurora.NewJobUpdateKey()
err237 := argvalue0.Read(jsProt236)
if err237 != nil {
err287 := argvalue0.Read(context.Background(), jsProt286)
if err287 != nil {
Usage()
return
}
@ -474,19 +474,19 @@ func main() {
fmt.Fprintln(os.Stderr, "ResumeJobUpdate requires 2 args")
flag.Usage()
}
arg239 := flag.Arg(1)
mbTrans240 := thrift.NewTMemoryBufferLen(len(arg239))
defer mbTrans240.Close()
_, err241 := mbTrans240.WriteString(arg239)
if err241 != nil {
arg289 := flag.Arg(1)
mbTrans290 := thrift.NewTMemoryBufferLen(len(arg289))
defer mbTrans290.Close()
_, err291 := mbTrans290.WriteString(arg289)
if err291 != nil {
Usage()
return
}
factory242 := thrift.NewTJSONProtocolFactory()
jsProt243 := factory242.GetProtocol(mbTrans240)
factory292 := thrift.NewTJSONProtocolFactory()
jsProt293 := factory292.GetProtocol(mbTrans290)
argvalue0 := aurora.NewJobUpdateKey()
err244 := argvalue0.Read(jsProt243)
if err244 != nil {
err294 := argvalue0.Read(context.Background(), jsProt293)
if err294 != nil {
Usage()
return
}
@ -501,19 +501,19 @@ func main() {
fmt.Fprintln(os.Stderr, "AbortJobUpdate requires 2 args")
flag.Usage()
}
arg246 := flag.Arg(1)
mbTrans247 := thrift.NewTMemoryBufferLen(len(arg246))
defer mbTrans247.Close()
_, err248 := mbTrans247.WriteString(arg246)
if err248 != nil {
arg296 := flag.Arg(1)
mbTrans297 := thrift.NewTMemoryBufferLen(len(arg296))
defer mbTrans297.Close()
_, err298 := mbTrans297.WriteString(arg296)
if err298 != nil {
Usage()
return
}
factory249 := thrift.NewTJSONProtocolFactory()
jsProt250 := factory249.GetProtocol(mbTrans247)
factory299 := thrift.NewTJSONProtocolFactory()
jsProt300 := factory299.GetProtocol(mbTrans297)
argvalue0 := aurora.NewJobUpdateKey()
err251 := argvalue0.Read(jsProt250)
if err251 != nil {
err301 := argvalue0.Read(context.Background(), jsProt300)
if err301 != nil {
Usage()
return
}
@ -528,19 +528,19 @@ func main() {
fmt.Fprintln(os.Stderr, "RollbackJobUpdate requires 2 args")
flag.Usage()
}
arg253 := flag.Arg(1)
mbTrans254 := thrift.NewTMemoryBufferLen(len(arg253))
defer mbTrans254.Close()
_, err255 := mbTrans254.WriteString(arg253)
if err255 != nil {
arg303 := flag.Arg(1)
mbTrans304 := thrift.NewTMemoryBufferLen(len(arg303))
defer mbTrans304.Close()
_, err305 := mbTrans304.WriteString(arg303)
if err305 != nil {
Usage()
return
}
factory256 := thrift.NewTJSONProtocolFactory()
jsProt257 := factory256.GetProtocol(mbTrans254)
factory306 := thrift.NewTJSONProtocolFactory()
jsProt307 := factory306.GetProtocol(mbTrans304)
argvalue0 := aurora.NewJobUpdateKey()
err258 := argvalue0.Read(jsProt257)
if err258 != nil {
err308 := argvalue0.Read(context.Background(), jsProt307)
if err308 != nil {
Usage()
return
}
@ -555,19 +555,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PulseJobUpdate requires 1 args")
flag.Usage()
}
arg260 := flag.Arg(1)
mbTrans261 := thrift.NewTMemoryBufferLen(len(arg260))
defer mbTrans261.Close()
_, err262 := mbTrans261.WriteString(arg260)
if err262 != nil {
arg310 := flag.Arg(1)
mbTrans311 := thrift.NewTMemoryBufferLen(len(arg310))
defer mbTrans311.Close()
_, err312 := mbTrans311.WriteString(arg310)
if err312 != nil {
Usage()
return
}
factory263 := thrift.NewTJSONProtocolFactory()
jsProt264 := factory263.GetProtocol(mbTrans261)
factory313 := thrift.NewTJSONProtocolFactory()
jsProt314 := factory313.GetProtocol(mbTrans311)
argvalue0 := aurora.NewJobUpdateKey()
err265 := argvalue0.Read(jsProt264)
if err265 != nil {
err315 := argvalue0.Read(context.Background(), jsProt314)
if err315 != nil {
Usage()
return
}
@ -598,19 +598,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksStatus requires 1 args")
flag.Usage()
}
arg267 := flag.Arg(1)
mbTrans268 := thrift.NewTMemoryBufferLen(len(arg267))
defer mbTrans268.Close()
_, err269 := mbTrans268.WriteString(arg267)
if err269 != nil {
arg317 := flag.Arg(1)
mbTrans318 := thrift.NewTMemoryBufferLen(len(arg317))
defer mbTrans318.Close()
_, err319 := mbTrans318.WriteString(arg317)
if err319 != nil {
Usage()
return
}
factory270 := thrift.NewTJSONProtocolFactory()
jsProt271 := factory270.GetProtocol(mbTrans268)
factory320 := thrift.NewTJSONProtocolFactory()
jsProt321 := factory320.GetProtocol(mbTrans318)
argvalue0 := aurora.NewTaskQuery()
err272 := argvalue0.Read(jsProt271)
if err272 != nil {
err322 := argvalue0.Read(context.Background(), jsProt321)
if err322 != nil {
Usage()
return
}
@ -623,19 +623,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksWithoutConfigs requires 1 args")
flag.Usage()
}
arg273 := flag.Arg(1)
mbTrans274 := thrift.NewTMemoryBufferLen(len(arg273))
defer mbTrans274.Close()
_, err275 := mbTrans274.WriteString(arg273)
if err275 != nil {
arg323 := flag.Arg(1)
mbTrans324 := thrift.NewTMemoryBufferLen(len(arg323))
defer mbTrans324.Close()
_, err325 := mbTrans324.WriteString(arg323)
if err325 != nil {
Usage()
return
}
factory276 := thrift.NewTJSONProtocolFactory()
jsProt277 := factory276.GetProtocol(mbTrans274)
factory326 := thrift.NewTJSONProtocolFactory()
jsProt327 := factory326.GetProtocol(mbTrans324)
argvalue0 := aurora.NewTaskQuery()
err278 := argvalue0.Read(jsProt277)
if err278 != nil {
err328 := argvalue0.Read(context.Background(), jsProt327)
if err328 != nil {
Usage()
return
}
@ -648,19 +648,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetPendingReason requires 1 args")
flag.Usage()
}
arg279 := flag.Arg(1)
mbTrans280 := thrift.NewTMemoryBufferLen(len(arg279))
defer mbTrans280.Close()
_, err281 := mbTrans280.WriteString(arg279)
if err281 != nil {
arg329 := flag.Arg(1)
mbTrans330 := thrift.NewTMemoryBufferLen(len(arg329))
defer mbTrans330.Close()
_, err331 := mbTrans330.WriteString(arg329)
if err331 != nil {
Usage()
return
}
factory282 := thrift.NewTJSONProtocolFactory()
jsProt283 := factory282.GetProtocol(mbTrans280)
factory332 := thrift.NewTJSONProtocolFactory()
jsProt333 := factory332.GetProtocol(mbTrans330)
argvalue0 := aurora.NewTaskQuery()
err284 := argvalue0.Read(jsProt283)
if err284 != nil {
err334 := argvalue0.Read(context.Background(), jsProt333)
if err334 != nil {
Usage()
return
}
@ -673,19 +673,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetConfigSummary requires 1 args")
flag.Usage()
}
arg285 := flag.Arg(1)
mbTrans286 := thrift.NewTMemoryBufferLen(len(arg285))
defer mbTrans286.Close()
_, err287 := mbTrans286.WriteString(arg285)
if err287 != nil {
arg335 := flag.Arg(1)
mbTrans336 := thrift.NewTMemoryBufferLen(len(arg335))
defer mbTrans336.Close()
_, err337 := mbTrans336.WriteString(arg335)
if err337 != nil {
Usage()
return
}
factory288 := thrift.NewTJSONProtocolFactory()
jsProt289 := factory288.GetProtocol(mbTrans286)
factory338 := thrift.NewTJSONProtocolFactory()
jsProt339 := factory338.GetProtocol(mbTrans336)
argvalue0 := aurora.NewJobKey()
err290 := argvalue0.Read(jsProt289)
if err290 != nil {
err340 := argvalue0.Read(context.Background(), jsProt339)
if err340 != nil {
Usage()
return
}
@ -718,19 +718,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PopulateJobConfig requires 1 args")
flag.Usage()
}
arg293 := flag.Arg(1)
mbTrans294 := thrift.NewTMemoryBufferLen(len(arg293))
defer mbTrans294.Close()
_, err295 := mbTrans294.WriteString(arg293)
if err295 != nil {
arg343 := flag.Arg(1)
mbTrans344 := thrift.NewTMemoryBufferLen(len(arg343))
defer mbTrans344.Close()
_, err345 := mbTrans344.WriteString(arg343)
if err345 != nil {
Usage()
return
}
factory296 := thrift.NewTJSONProtocolFactory()
jsProt297 := factory296.GetProtocol(mbTrans294)
factory346 := thrift.NewTJSONProtocolFactory()
jsProt347 := factory346.GetProtocol(mbTrans344)
argvalue0 := aurora.NewJobConfiguration()
err298 := argvalue0.Read(jsProt297)
if err298 != nil {
err348 := argvalue0.Read(context.Background(), jsProt347)
if err348 != nil {
Usage()
return
}
@ -743,19 +743,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateSummaries requires 1 args")
flag.Usage()
}
arg299 := flag.Arg(1)
mbTrans300 := thrift.NewTMemoryBufferLen(len(arg299))
defer mbTrans300.Close()
_, err301 := mbTrans300.WriteString(arg299)
if err301 != nil {
arg349 := flag.Arg(1)
mbTrans350 := thrift.NewTMemoryBufferLen(len(arg349))
defer mbTrans350.Close()
_, err351 := mbTrans350.WriteString(arg349)
if err351 != nil {
Usage()
return
}
factory302 := thrift.NewTJSONProtocolFactory()
jsProt303 := factory302.GetProtocol(mbTrans300)
factory352 := thrift.NewTJSONProtocolFactory()
jsProt353 := factory352.GetProtocol(mbTrans350)
argvalue0 := aurora.NewJobUpdateQuery()
err304 := argvalue0.Read(jsProt303)
if err304 != nil {
err354 := argvalue0.Read(context.Background(), jsProt353)
if err354 != nil {
Usage()
return
}
@ -768,19 +768,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDetails requires 1 args")
flag.Usage()
}
arg305 := flag.Arg(1)
mbTrans306 := thrift.NewTMemoryBufferLen(len(arg305))
defer mbTrans306.Close()
_, err307 := mbTrans306.WriteString(arg305)
if err307 != nil {
arg355 := flag.Arg(1)
mbTrans356 := thrift.NewTMemoryBufferLen(len(arg355))
defer mbTrans356.Close()
_, err357 := mbTrans356.WriteString(arg355)
if err357 != nil {
Usage()
return
}
factory308 := thrift.NewTJSONProtocolFactory()
jsProt309 := factory308.GetProtocol(mbTrans306)
factory358 := thrift.NewTJSONProtocolFactory()
jsProt359 := factory358.GetProtocol(mbTrans356)
argvalue0 := aurora.NewJobUpdateQuery()
err310 := argvalue0.Read(jsProt309)
if err310 != nil {
err360 := argvalue0.Read(context.Background(), jsProt359)
if err360 != nil {
Usage()
return
}
@ -793,19 +793,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDiff requires 1 args")
flag.Usage()
}
arg311 := flag.Arg(1)
mbTrans312 := thrift.NewTMemoryBufferLen(len(arg311))
defer mbTrans312.Close()
_, err313 := mbTrans312.WriteString(arg311)
if err313 != nil {
arg361 := flag.Arg(1)
mbTrans362 := thrift.NewTMemoryBufferLen(len(arg361))
defer mbTrans362.Close()
_, err363 := mbTrans362.WriteString(arg361)
if err363 != nil {
Usage()
return
}
factory314 := thrift.NewTJSONProtocolFactory()
jsProt315 := factory314.GetProtocol(mbTrans312)
factory364 := thrift.NewTJSONProtocolFactory()
jsProt365 := factory364.GetProtocol(mbTrans362)
argvalue0 := aurora.NewJobUpdateRequest()
err316 := argvalue0.Read(jsProt315)
if err316 != nil {
err366 := argvalue0.Read(context.Background(), jsProt365)
if err366 != nil {
Usage()
return
}

View file

@ -1,22 +1,22 @@
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
// Code generated by Thrift Compiler (0.14.0). DO NOT EDIT.
package main
import (
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"strconv"
"strings"
"github.com/apache/thrift/lib/go/thrift"
"apache/aurora"
)
var _ = aurora.GoUnusedProtection__
func Usage() {
fmt.Fprintln(os.Stderr, "Usage of ", os.Args[0], " [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]:")
@ -179,19 +179,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksStatus requires 1 args")
flag.Usage()
}
arg82 := flag.Arg(1)
mbTrans83 := thrift.NewTMemoryBufferLen(len(arg82))
defer mbTrans83.Close()
_, err84 := mbTrans83.WriteString(arg82)
if err84 != nil {
arg132 := flag.Arg(1)
mbTrans133 := thrift.NewTMemoryBufferLen(len(arg132))
defer mbTrans133.Close()
_, err134 := mbTrans133.WriteString(arg132)
if err134 != nil {
Usage()
return
}
factory85 := thrift.NewTJSONProtocolFactory()
jsProt86 := factory85.GetProtocol(mbTrans83)
factory135 := thrift.NewTJSONProtocolFactory()
jsProt136 := factory135.GetProtocol(mbTrans133)
argvalue0 := aurora.NewTaskQuery()
err87 := argvalue0.Read(jsProt86)
if err87 != nil {
err137 := argvalue0.Read(context.Background(), jsProt136)
if err137 != nil {
Usage()
return
}
@ -204,19 +204,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetTasksWithoutConfigs requires 1 args")
flag.Usage()
}
arg88 := flag.Arg(1)
mbTrans89 := thrift.NewTMemoryBufferLen(len(arg88))
defer mbTrans89.Close()
_, err90 := mbTrans89.WriteString(arg88)
if err90 != nil {
arg138 := flag.Arg(1)
mbTrans139 := thrift.NewTMemoryBufferLen(len(arg138))
defer mbTrans139.Close()
_, err140 := mbTrans139.WriteString(arg138)
if err140 != nil {
Usage()
return
}
factory91 := thrift.NewTJSONProtocolFactory()
jsProt92 := factory91.GetProtocol(mbTrans89)
factory141 := thrift.NewTJSONProtocolFactory()
jsProt142 := factory141.GetProtocol(mbTrans139)
argvalue0 := aurora.NewTaskQuery()
err93 := argvalue0.Read(jsProt92)
if err93 != nil {
err143 := argvalue0.Read(context.Background(), jsProt142)
if err143 != nil {
Usage()
return
}
@ -229,19 +229,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetPendingReason requires 1 args")
flag.Usage()
}
arg94 := flag.Arg(1)
mbTrans95 := thrift.NewTMemoryBufferLen(len(arg94))
defer mbTrans95.Close()
_, err96 := mbTrans95.WriteString(arg94)
if err96 != nil {
arg144 := flag.Arg(1)
mbTrans145 := thrift.NewTMemoryBufferLen(len(arg144))
defer mbTrans145.Close()
_, err146 := mbTrans145.WriteString(arg144)
if err146 != nil {
Usage()
return
}
factory97 := thrift.NewTJSONProtocolFactory()
jsProt98 := factory97.GetProtocol(mbTrans95)
factory147 := thrift.NewTJSONProtocolFactory()
jsProt148 := factory147.GetProtocol(mbTrans145)
argvalue0 := aurora.NewTaskQuery()
err99 := argvalue0.Read(jsProt98)
if err99 != nil {
err149 := argvalue0.Read(context.Background(), jsProt148)
if err149 != nil {
Usage()
return
}
@ -254,19 +254,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetConfigSummary requires 1 args")
flag.Usage()
}
arg100 := flag.Arg(1)
mbTrans101 := thrift.NewTMemoryBufferLen(len(arg100))
defer mbTrans101.Close()
_, err102 := mbTrans101.WriteString(arg100)
if err102 != nil {
arg150 := flag.Arg(1)
mbTrans151 := thrift.NewTMemoryBufferLen(len(arg150))
defer mbTrans151.Close()
_, err152 := mbTrans151.WriteString(arg150)
if err152 != nil {
Usage()
return
}
factory103 := thrift.NewTJSONProtocolFactory()
jsProt104 := factory103.GetProtocol(mbTrans101)
factory153 := thrift.NewTJSONProtocolFactory()
jsProt154 := factory153.GetProtocol(mbTrans151)
argvalue0 := aurora.NewJobKey()
err105 := argvalue0.Read(jsProt104)
if err105 != nil {
err155 := argvalue0.Read(context.Background(), jsProt154)
if err155 != nil {
Usage()
return
}
@ -299,19 +299,19 @@ func main() {
fmt.Fprintln(os.Stderr, "PopulateJobConfig requires 1 args")
flag.Usage()
}
arg108 := flag.Arg(1)
mbTrans109 := thrift.NewTMemoryBufferLen(len(arg108))
defer mbTrans109.Close()
_, err110 := mbTrans109.WriteString(arg108)
if err110 != nil {
arg158 := flag.Arg(1)
mbTrans159 := thrift.NewTMemoryBufferLen(len(arg158))
defer mbTrans159.Close()
_, err160 := mbTrans159.WriteString(arg158)
if err160 != nil {
Usage()
return
}
factory111 := thrift.NewTJSONProtocolFactory()
jsProt112 := factory111.GetProtocol(mbTrans109)
factory161 := thrift.NewTJSONProtocolFactory()
jsProt162 := factory161.GetProtocol(mbTrans159)
argvalue0 := aurora.NewJobConfiguration()
err113 := argvalue0.Read(jsProt112)
if err113 != nil {
err163 := argvalue0.Read(context.Background(), jsProt162)
if err163 != nil {
Usage()
return
}
@ -324,19 +324,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateSummaries requires 1 args")
flag.Usage()
}
arg114 := flag.Arg(1)
mbTrans115 := thrift.NewTMemoryBufferLen(len(arg114))
defer mbTrans115.Close()
_, err116 := mbTrans115.WriteString(arg114)
if err116 != nil {
arg164 := flag.Arg(1)
mbTrans165 := thrift.NewTMemoryBufferLen(len(arg164))
defer mbTrans165.Close()
_, err166 := mbTrans165.WriteString(arg164)
if err166 != nil {
Usage()
return
}
factory117 := thrift.NewTJSONProtocolFactory()
jsProt118 := factory117.GetProtocol(mbTrans115)
factory167 := thrift.NewTJSONProtocolFactory()
jsProt168 := factory167.GetProtocol(mbTrans165)
argvalue0 := aurora.NewJobUpdateQuery()
err119 := argvalue0.Read(jsProt118)
if err119 != nil {
err169 := argvalue0.Read(context.Background(), jsProt168)
if err169 != nil {
Usage()
return
}
@ -349,19 +349,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDetails requires 1 args")
flag.Usage()
}
arg120 := flag.Arg(1)
mbTrans121 := thrift.NewTMemoryBufferLen(len(arg120))
defer mbTrans121.Close()
_, err122 := mbTrans121.WriteString(arg120)
if err122 != nil {
arg170 := flag.Arg(1)
mbTrans171 := thrift.NewTMemoryBufferLen(len(arg170))
defer mbTrans171.Close()
_, err172 := mbTrans171.WriteString(arg170)
if err172 != nil {
Usage()
return
}
factory123 := thrift.NewTJSONProtocolFactory()
jsProt124 := factory123.GetProtocol(mbTrans121)
factory173 := thrift.NewTJSONProtocolFactory()
jsProt174 := factory173.GetProtocol(mbTrans171)
argvalue0 := aurora.NewJobUpdateQuery()
err125 := argvalue0.Read(jsProt124)
if err125 != nil {
err175 := argvalue0.Read(context.Background(), jsProt174)
if err175 != nil {
Usage()
return
}
@ -374,19 +374,19 @@ func main() {
fmt.Fprintln(os.Stderr, "GetJobUpdateDiff requires 1 args")
flag.Usage()
}
arg126 := flag.Arg(1)
mbTrans127 := thrift.NewTMemoryBufferLen(len(arg126))
defer mbTrans127.Close()
_, err128 := mbTrans127.WriteString(arg126)
if err128 != nil {
arg176 := flag.Arg(1)
mbTrans177 := thrift.NewTMemoryBufferLen(len(arg176))
defer mbTrans177.Close()
_, err178 := mbTrans177.WriteString(arg176)
if err178 != nil {
Usage()
return
}
factory129 := thrift.NewTJSONProtocolFactory()
jsProt130 := factory129.GetProtocol(mbTrans127)
factory179 := thrift.NewTJSONProtocolFactory()
jsProt180 := factory179.GetProtocol(mbTrans177)
argvalue0 := aurora.NewJobUpdateRequest()
err131 := argvalue0.Read(jsProt130)
if err131 != nil {
err181 := argvalue0.Read(context.Background(), jsProt180)
if err181 != nil {
Usage()
return
}

View file

@ -1,6 +1,6 @@
#! /bin/bash
THRIFT_VER=0.12.0
THRIFT_VER=0.14.0
if [[ $(thrift -version | grep -e $THRIFT_VER -c) -ne 1 ]]; then
echo "Warning: This wrapper has only been tested with version" $THRIFT_VER;

14
go.mod
View file

@ -1,12 +1,10 @@
module github.com/paypal/gorealis
go 1.12
module github.com/aurora-scheduler/gorealis/v2
require (
github.com/apache/thrift v0.12.0
github.com/davecgh/go-spew v1.1.0
github.com/pkg/errors v0.0.0-20171216070316-e881fd58d78e
github.com/pmezard/go-difflib v1.0.0
github.com/apache/thrift v0.14.0
github.com/pkg/errors v0.9.1
github.com/samuel/go-zookeeper v0.0.0-20171117190445-471cd4e61d7a
github.com/stretchr/testify v1.2.0
github.com/stretchr/testify v1.7.0
)
go 1.16

23
go.sum
View file

@ -1,9 +1,22 @@
github.com/apache/thrift v0.12.0 h1:pODnxUFNcjP9UTLZGTdeh+j16A8lJbRvD3rOtrk/7bs=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/apache/thrift v0.14.0 h1:vqZ2DP42i8th2OsgCcYZkirtbzvpZEFx53LiWDJXIAs=
github.com/apache/thrift v0.14.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pkg/errors v0.0.0-20171216070316-e881fd58d78e h1:+RHxT/gm0O3UF7nLJbdNzAmULvCFt4XfXHWzh3XI/zs=
github.com/pkg/errors v0.0.0-20171216070316-e881fd58d78e/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/samuel/go-zookeeper v0.0.0-20171117190445-471cd4e61d7a h1:EYL2xz/Zdo0hyqdZMXR4lmT2O11jDLTPCEqIe/FR6W4=
github.com/samuel/go-zookeeper v0.0.0-20171117190445-471cd4e61d7a/go.mod h1:gi+0XIa01GRL2eRQVjQkKGqKF3SF9vZR/HnPullcV2E=
github.com/stretchr/testify v1.2.0/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.0 h1:DMOzIV76tmoDNE9pX6RSN0aDtCYeCg5VueieJaAo1uw=
github.com/stretchr/testify v1.5.0/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

23
helpers.go Normal file
View file

@ -0,0 +1,23 @@
package realis
import (
"context"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
func (r *Client) JobExists(key aurora.JobKey) (bool, error) {
resp, err := r.client.GetConfigSummary(context.TODO(), &key)
if err != nil {
return false, err
}
return resp != nil &&
resp.GetResult_() != nil &&
resp.GetResult_().GetConfigSummaryResult_() != nil &&
resp.GetResult_().GetConfigSummaryResult_().GetSummary() != nil &&
resp.GetResult_().GetConfigSummaryResult_().GetSummary().GetGroups() != nil &&
len(resp.GetResult_().GetConfigSummaryResult_().GetSummary().GetGroups()) > 0 &&
resp.GetResponseCode() == aurora.ResponseCode_OK,
nil
}

395
job.go
View file

@ -15,358 +15,213 @@
package realis
import (
"strconv"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Job inteface is used to define a set of functions an Aurora Job object
// must implemement.
// TODO(rdelvalle): Consider getting rid of the Job interface
type Job interface {
// Set Job Key environment.
Environment(env string) Job
Role(role string) Job
Name(name string) Job
CronSchedule(cron string) Job
CronCollisionPolicy(policy aurora.CronCollisionPolicy) Job
CPU(cpus float64) Job
Disk(disk int64) Job
RAM(ram int64) Job
GPU(gpu int64) Job
ExecutorName(name string) Job
ExecutorData(data string) Job
AddPorts(num int) Job
AddLabel(key string, value string) Job
AddNamedPorts(names ...string) Job
AddLimitConstraint(name string, limit int32) Job
AddValueConstraint(name string, negated bool, values ...string) Job
// From Aurora Docs:
// dedicated attribute. Aurora treats this specially, and only allows matching jobs
// to run on these machines, and will only schedule matching jobs on these machines.
// When a job is created, the scheduler requires that the $role component matches
// the role field in the job configuration, and will reject the job creation otherwise.
// A wildcard (*) may be used for the role portion of the dedicated attribute, which
// will allow any owner to elect for a job to run on the host(s)
AddDedicatedConstraint(role, name string) Job
AddURIs(extract bool, cache bool, values ...string) Job
JobKey() *aurora.JobKey
JobConfig() *aurora.JobConfiguration
TaskConfig() *aurora.TaskConfig
IsService(isService bool) Job
InstanceCount(instCount int32) Job
GetInstanceCount() int32
MaxFailure(maxFail int32) Job
Container(container Container) Job
PartitionPolicy(policy *aurora.PartitionPolicy) Job
Tier(tier string) Job
SlaPolicy(policy *aurora.SlaPolicy) Job
}
type resourceType int
const (
CPU resourceType = iota
RAM
DISK
GPU
)
// AuroraJob is a structure to collect all information pertaining to an Aurora job.
// Structure to collect all information pertaining to an Aurora job.
type AuroraJob struct {
jobConfig *aurora.JobConfiguration
resources map[resourceType]*aurora.Resource
metadata map[string]*aurora.Metadata
portCount int
task *AuroraTask
}
// NewJob is used to create a Job object with everything initialized.
func NewJob() Job {
jobConfig := aurora.NewJobConfiguration()
taskConfig := aurora.NewTaskConfig()
jobKey := aurora.NewJobKey()
// Create a AuroraJob object with everything initialized.
func NewJob() *AuroraJob {
// Job Config
jobConfig.Key = jobKey
jobConfig.TaskConfig = taskConfig
jobKey := &aurora.JobKey{}
// Task Config
taskConfig.Job = jobKey
taskConfig.Container = aurora.NewContainer()
taskConfig.Container.Mesos = aurora.NewMesosContainer()
// AuroraTask clientConfig
task := NewTask()
task.task.Job = jobKey
// Resources
numCpus := aurora.NewResource()
ramMb := aurora.NewResource()
diskMb := aurora.NewResource()
resources := map[resourceType]*aurora.Resource{CPU: numCpus, RAM: ramMb, DISK: diskMb}
taskConfig.Resources = []*aurora.Resource{numCpus, ramMb, diskMb}
numCpus.NumCpus = new(float64)
ramMb.RamMb = new(int64)
diskMb.DiskMb = new(int64)
// AuroraJob clientConfig
jobConfig := &aurora.JobConfiguration{
Key: jobKey,
TaskConfig: task.TaskConfig(),
}
return &AuroraJob{
jobConfig: jobConfig,
resources: resources,
metadata: make(map[string]*aurora.Metadata),
portCount: 0,
task: task,
}
}
// Environment sets the Job Key environment.
func (j *AuroraJob) Environment(env string) Job {
// Set AuroraJob Key environment. Explicit changes to AuroraTask's job key are not needed
// because they share a pointer to the same JobKey.
func (j *AuroraJob) Environment(env string) *AuroraJob {
j.jobConfig.Key.Environment = env
return j
}
// Role sets the Job Key role.
func (j *AuroraJob) Role(role string) Job {
// Set AuroraJob Key Role.
func (j *AuroraJob) Role(role string) *AuroraJob {
j.jobConfig.Key.Role = role
// Will be deprecated
identity := &aurora.Identity{User: role}
j.jobConfig.Owner = identity
j.jobConfig.TaskConfig.Owner = identity
return j
}
// Name sets the Job Key Name.
func (j *AuroraJob) Name(name string) Job {
// Set AuroraJob Key Name.
func (j *AuroraJob) Name(name string) *AuroraJob {
j.jobConfig.Key.Name = name
return j
}
// ExecutorName sets the name of the executor that will the task will be configured to.
func (j *AuroraJob) ExecutorName(name string) Job {
if j.jobConfig.TaskConfig.ExecutorConfig == nil {
j.jobConfig.TaskConfig.ExecutorConfig = aurora.NewExecutorConfig()
}
j.jobConfig.TaskConfig.ExecutorConfig.Name = name
return j
}
// ExecutorData sets the data blob that will be passed to the Mesos executor.
func (j *AuroraJob) ExecutorData(data string) Job {
if j.jobConfig.TaskConfig.ExecutorConfig == nil {
j.jobConfig.TaskConfig.ExecutorConfig = aurora.NewExecutorConfig()
}
j.jobConfig.TaskConfig.ExecutorConfig.Data = data
return j
}
// CPU sets the amount of CPU each task will use in an Aurora Job.
func (j *AuroraJob) CPU(cpus float64) Job {
*j.resources[CPU].NumCpus = cpus
return j
}
// RAM sets the amount of RAM each task will use in an Aurora Job.
func (j *AuroraJob) RAM(ram int64) Job {
*j.resources[RAM].RamMb = ram
return j
}
// Disk sets the amount of Disk each task will use in an Aurora Job.
func (j *AuroraJob) Disk(disk int64) Job {
*j.resources[DISK].DiskMb = disk
return j
}
// GPU sets the amount of GPU each task will use in an Aurora Job.
func (j *AuroraJob) GPU(gpu int64) Job {
// GPU resource must be set explicitly since the scheduler by default
// rejects jobs with GPU resources attached to it.
if _, ok := j.resources[GPU]; !ok {
j.resources[GPU] = &aurora.Resource{}
j.JobConfig().GetTaskConfig().Resources = append(
j.JobConfig().GetTaskConfig().Resources,
j.resources[GPU])
}
j.resources[GPU].NumGpus = &gpu
return j
}
// MaxFailure sets how many failures to tolerate before giving up per Job.
func (j *AuroraJob) MaxFailure(maxFail int32) Job {
j.jobConfig.TaskConfig.MaxTaskFailures = maxFail
return j
}
// InstanceCount sets how many instances of the task to run for this Job.
func (j *AuroraJob) InstanceCount(instCount int32) Job {
// How many instances of the job to run
func (j *AuroraJob) InstanceCount(instCount int32) *AuroraJob {
j.jobConfig.InstanceCount = instCount
return j
}
// CronSchedule allows the user to configure a cron schedule for this job to run in.
func (j *AuroraJob) CronSchedule(cron string) Job {
func (j *AuroraJob) CronSchedule(cron string) *AuroraJob {
j.jobConfig.CronSchedule = &cron
return j
}
// CronCollisionPolicy allows the user to decide what happens if two or more instances
// of the same Cron job need to run.
func (j *AuroraJob) CronCollisionPolicy(policy aurora.CronCollisionPolicy) Job {
func (j *AuroraJob) CronCollisionPolicy(policy aurora.CronCollisionPolicy) *AuroraJob {
j.jobConfig.CronCollisionPolicy = policy
return j
}
// GetInstanceCount returns how many tasks this Job contains.
// How many instances of the job to run
func (j *AuroraJob) GetInstanceCount() int32 {
return j.jobConfig.InstanceCount
}
// IsService returns true if the job is a long term running job or false if it is an ad-hoc job.
func (j *AuroraJob) IsService(isService bool) Job {
j.jobConfig.TaskConfig.IsService = isService
return j
// Get the current job configurations key to use for some realis calls.
func (j *AuroraJob) JobKey() aurora.JobKey {
return *j.jobConfig.Key
}
// JobKey returns the job's configuration key.
func (j *AuroraJob) JobKey() *aurora.JobKey {
return j.jobConfig.Key
}
// JobConfig returns the job's configuration.
// Get the current job configurations key to use for some realis calls.
func (j *AuroraJob) JobConfig() *aurora.JobConfiguration {
return j.jobConfig
}
// TaskConfig returns the job's task(shard) configuration.
// Get the current job configurations key to use for some realis calls.
func (j *AuroraJob) AuroraTask() *AuroraTask {
return j.task
}
/*
AuroraTask specific API, see task.go for further documentation.
These functions are provided for the convenience of chaining API calls.
*/
func (j *AuroraJob) ExecutorName(name string) *AuroraJob {
j.task.ExecutorName(name)
return j
}
func (j *AuroraJob) ExecutorData(data string) *AuroraJob {
j.task.ExecutorData(data)
return j
}
func (j *AuroraJob) CPU(cpus float64) *AuroraJob {
j.task.CPU(cpus)
return j
}
func (j *AuroraJob) RAM(ram int64) *AuroraJob {
j.task.RAM(ram)
return j
}
func (j *AuroraJob) Disk(disk int64) *AuroraJob {
j.task.Disk(disk)
return j
}
func (j *AuroraJob) GPU(gpu int64) *AuroraJob {
j.task.GPU(gpu)
return j
}
func (j *AuroraJob) Tier(tier string) *AuroraJob {
j.task.Tier(tier)
return j
}
func (j *AuroraJob) MaxFailure(maxFail int32) *AuroraJob {
j.task.MaxFailure(maxFail)
return j
}
func (j *AuroraJob) IsService(isService bool) *AuroraJob {
j.task.IsService(isService)
return j
}
func (j *AuroraJob) Priority(priority int32) *AuroraJob {
j.task.Priority(priority)
return j
}
func (j *AuroraJob) Production(production bool) *AuroraJob {
j.task.Production(production)
return j
}
func (j *AuroraJob) TaskConfig() *aurora.TaskConfig {
return j.jobConfig.TaskConfig
return j.task.TaskConfig()
}
// AddURIs adds a list of URIs with the same extract and cache configuration. Scheduler must have
// --enable_mesos_fetcher flag enabled. Currently there is no duplicate detection.
func (j *AuroraJob) AddURIs(extract bool, cache bool, values ...string) Job {
for _, value := range values {
j.jobConfig.TaskConfig.MesosFetcherUris = append(j.jobConfig.TaskConfig.MesosFetcherUris,
&aurora.MesosFetcherURI{Value: value, Extract: &extract, Cache: &cache})
}
func (j *AuroraJob) AddURIs(extract bool, cache bool, values ...string) *AuroraJob {
j.task.AddURIs(extract, cache, values...)
return j
}
// AddLabel adds a Mesos label to the job. Note that Aurora will add the
// prefix "org.apache.aurora.metadata." to the beginning of each key.
func (j *AuroraJob) AddLabel(key string, value string) Job {
if _, ok := j.metadata[key]; ok {
j.metadata[key].Value = value
} else {
j.metadata[key] = &aurora.Metadata{Key: key, Value: value}
j.jobConfig.TaskConfig.Metadata = append(j.jobConfig.TaskConfig.Metadata, j.metadata[key])
}
func (j *AuroraJob) AddLabel(key string, value string) *AuroraJob {
j.task.AddLabel(key, value)
return j
}
// AddNamedPorts adds a named port to the job configuration These are random ports as it's
// not currently possible to request specific ports using Aurora.
func (j *AuroraJob) AddNamedPorts(names ...string) Job {
j.portCount += len(names)
for _, name := range names {
j.jobConfig.TaskConfig.Resources = append(
j.jobConfig.TaskConfig.Resources,
&aurora.Resource{NamedPort: &name})
}
func (j *AuroraJob) AddNamedPorts(names ...string) *AuroraJob {
j.task.AddNamedPorts(names...)
return j
}
// AddPorts adds a request for a number of ports to the job configuration. The names chosen for these ports
// will be org.apache.aurora.port.X, where X is the current port count for the job configuration
// starting at 0. These are random ports as it's not currently possible to request
// specific ports using Aurora.
func (j *AuroraJob) AddPorts(num int) Job {
start := j.portCount
j.portCount += num
for i := start; i < j.portCount; i++ {
portName := "org.apache.aurora.port." + strconv.Itoa(i)
j.jobConfig.TaskConfig.Resources = append(
j.jobConfig.TaskConfig.Resources,
&aurora.Resource{NamedPort: &portName})
}
func (j *AuroraJob) AddPorts(num int) *AuroraJob {
j.task.AddPorts(num)
return j
}
func (j *AuroraJob) AddValueConstraint(name string, negated bool, values ...string) *AuroraJob {
j.task.AddValueConstraint(name, negated, values...)
return j
}
// AddValueConstraint allows the user to add a value constrain to the job to limiti which agents the job's
// tasks can be run on.
// From Aurora Docs:
// Add a Value constraint
// name - Mesos slave attribute that the constraint is matched against.
// If negated = true , treat this as a 'not' - to avoid specific values.
// Values - list of values we look for in attribute name
func (j *AuroraJob) AddValueConstraint(name string, negated bool, values ...string) Job {
j.jobConfig.TaskConfig.Constraints = append(j.jobConfig.TaskConfig.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: &aurora.ValueConstraint{
Negated: negated,
Values: values,
},
Limit: nil,
},
})
func (j *AuroraJob) AddLimitConstraint(name string, limit int32) *AuroraJob {
j.task.AddLimitConstraint(name, limit)
return j
}
// AddLimitConstraint allows the user to limit how many tasks form the same Job are run on a single host.
// From Aurora Docs:
// A constraint that specifies the maximum number of active tasks on a host with
// a matching attribute that may be scheduled simultaneously.
func (j *AuroraJob) AddLimitConstraint(name string, limit int32) Job {
j.jobConfig.TaskConfig.Constraints = append(j.jobConfig.TaskConfig.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: nil,
Limit: &aurora.LimitConstraint{Limit: limit},
},
})
func (j *AuroraJob) AddDedicatedConstraint(role, name string) *AuroraJob {
j.task.AddDedicatedConstraint(role, name)
return j
}
// AddDedicatedConstraint allows the user to add a dedicated constraint to a Job configuration.
func (j *AuroraJob) AddDedicatedConstraint(role, name string) Job {
j.AddValueConstraint("dedicated", false, role+"/"+name)
func (j *AuroraJob) Container(container Container) *AuroraJob {
j.task.Container(container)
return j
}
// Container sets a container to run for the job configuration to run.
func (j *AuroraJob) Container(container Container) Job {
j.jobConfig.TaskConfig.Container = container.Build()
func (j *AuroraJob) ThermosExecutor(thermos ThermosExecutor) *AuroraJob {
j.task.ThermosExecutor(thermos)
return j
}
// PartitionPolicy sets a partition policy for the job configuration to implement.
func (j *AuroraJob) PartitionPolicy(policy *aurora.PartitionPolicy) Job {
j.jobConfig.TaskConfig.PartitionPolicy = policy
return j
func (j *AuroraJob) BuildThermosPayload() error {
return j.task.BuildThermosPayload()
}
// Tier sets the Tier for the Job.
func (j *AuroraJob) Tier(tier string) Job {
j.jobConfig.TaskConfig.Tier = &tier
return j
}
// SlaPolicy sets an SlaPolicy for the Job.
func (j *AuroraJob) SlaPolicy(policy *aurora.SlaPolicy) Job {
j.jobConfig.TaskConfig.SlaPolicy = policy
func (j *AuroraJob) PartitionPolicy(reschedule bool, delay int64) *AuroraJob {
j.task.PartitionPolicy(aurora.PartitionPolicy{
Reschedule: reschedule,
DelaySecs: &delay,
})
return j
}

296
jobUpdate.go Normal file
View file

@ -0,0 +1,296 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"time"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Structure to collect all information required to create job update
type JobUpdate struct {
task *AuroraTask
request *aurora.JobUpdateRequest
}
// Create a default JobUpdate object with an empty task and no fields filled in.
func NewJobUpdate() *JobUpdate {
newTask := NewTask()
return &JobUpdate{
task: newTask,
request: &aurora.JobUpdateRequest{TaskConfig: newTask.TaskConfig(), Settings: newUpdateSettings()},
}
}
// Creates an update with default values using an AuroraTask as the underlying task configuration.
// This function has a high level understanding of Aurora Tasks and thus will support copying a task that is configured
// to use Thermos.
func JobUpdateFromAuroraTask(task *AuroraTask) *JobUpdate {
newTask := task.Clone()
return &JobUpdate{
task: newTask,
request: &aurora.JobUpdateRequest{TaskConfig: newTask.TaskConfig(), Settings: newUpdateSettings()},
}
}
// JobUpdateFromConfig creates an update with default values using an aurora.TaskConfig
// primitive as the underlying task configuration.
// This function should not be used unless the implications of using a primitive value are understood.
// For example, the primitive has no concept of Thermos.
func JobUpdateFromConfig(task *aurora.TaskConfig) *JobUpdate {
// Perform a deep copy to avoid unexpected behavior
newTask := TaskFromThrift(task)
return &JobUpdate{
task: newTask,
request: &aurora.JobUpdateRequest{TaskConfig: newTask.TaskConfig(), Settings: newUpdateSettings()},
}
}
// Set instance count the job will have after the update.
func (j *JobUpdate) InstanceCount(inst int32) *JobUpdate {
j.request.InstanceCount = inst
return j
}
// Max number of instances being updated at any given moment.
func (j *JobUpdate) BatchSize(size int32) *JobUpdate {
j.request.Settings.UpdateGroupSize = size
return j
}
// Minimum number of seconds a shard must remain in RUNNING state before considered a success.
func (j *JobUpdate) WatchTime(timeout time.Duration) *JobUpdate {
j.request.Settings.MinWaitInInstanceRunningMs = int32(timeout.Milliseconds())
return j
}
// Wait for all instances in a group to be done before moving on.
func (j *JobUpdate) WaitForBatchCompletion(batchWait bool) *JobUpdate {
j.request.Settings.WaitForBatchCompletion = batchWait
return j
}
// Max number of instance failures to tolerate before marking instance as FAILED.
func (j *JobUpdate) MaxPerInstanceFailures(inst int32) *JobUpdate {
j.request.Settings.MaxPerInstanceFailures = inst
return j
}
// Max number of FAILED instances to tolerate before terminating the update.
func (j *JobUpdate) MaxFailedInstances(inst int32) *JobUpdate {
j.request.Settings.MaxFailedInstances = inst
return j
}
// When False, prevents auto rollback of a failed update.
func (j *JobUpdate) RollbackOnFail(rollback bool) *JobUpdate {
j.request.Settings.RollbackOnFailure = rollback
return j
}
// Sets the interval at which pulses should be received by the job update before timing out.
func (j *JobUpdate) PulseIntervalTimeout(timeout time.Duration) *JobUpdate {
j.request.Settings.BlockIfNoPulsesAfterMs = thrift.Int32Ptr(int32(timeout.Seconds() * 1000))
return j
}
func (j *JobUpdate) BatchUpdateStrategy(autoPause bool, batchSize int32) *JobUpdate {
j.request.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{
BatchStrategy: &aurora.BatchJobUpdateStrategy{GroupSize: batchSize, AutopauseAfterBatch: autoPause},
}
return j
}
func (j *JobUpdate) QueueUpdateStrategy(groupSize int32) *JobUpdate {
j.request.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{
QueueStrategy: &aurora.QueueJobUpdateStrategy{GroupSize: groupSize},
}
return j
}
func (j *JobUpdate) VariableBatchStrategy(autoPause bool, batchSizes ...int32) *JobUpdate {
j.request.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{
VarBatchStrategy: &aurora.VariableBatchJobUpdateStrategy{GroupSizes: batchSizes, AutopauseAfterBatch: autoPause},
}
return j
}
// SlaAware makes the scheduler enforce the SLA Aware policy if the job meets the SLA awareness criteria.
// By default, the scheduler will only apply SLA Awareness to jobs in the production tier with 20 or more instances.
func (j *JobUpdate) SlaAware(slaAware bool) *JobUpdate {
j.request.Settings.SlaAware = &slaAware
return j
}
// AddInstanceRange allows updates to only touch a certain specific range of instances
func (j *JobUpdate) AddInstanceRange(first, last int32) *JobUpdate {
j.request.Settings.UpdateOnlyTheseInstances = append(j.request.Settings.UpdateOnlyTheseInstances,
&aurora.Range{First: first, Last: last})
return j
}
func newUpdateSettings() *aurora.JobUpdateSettings {
us := aurora.JobUpdateSettings{}
// Mirrors defaults set by Pystachio
us.UpdateOnlyTheseInstances = []*aurora.Range{}
us.UpdateGroupSize = 1
us.WaitForBatchCompletion = false
us.MinWaitInInstanceRunningMs = 45000
us.MaxPerInstanceFailures = 0
us.MaxFailedInstances = 0
us.RollbackOnFailure = true
return &us
}
/*
These methods are provided for user convenience in order to chain
calls for configuration.
API below here are wrappers around modifying an AuroraTask instance.
See task.go for further documentation.
*/
func (j *JobUpdate) Environment(env string) *JobUpdate {
j.task.Environment(env)
return j
}
func (j *JobUpdate) Role(role string) *JobUpdate {
j.task.Role(role)
return j
}
func (j *JobUpdate) Name(name string) *JobUpdate {
j.task.Name(name)
return j
}
func (j *JobUpdate) ExecutorName(name string) *JobUpdate {
j.task.ExecutorName(name)
return j
}
func (j *JobUpdate) ExecutorData(data string) *JobUpdate {
j.task.ExecutorData(data)
return j
}
func (j *JobUpdate) CPU(cpus float64) *JobUpdate {
j.task.CPU(cpus)
return j
}
func (j *JobUpdate) RAM(ram int64) *JobUpdate {
j.task.RAM(ram)
return j
}
func (j *JobUpdate) Disk(disk int64) *JobUpdate {
j.task.Disk(disk)
return j
}
func (j *JobUpdate) Tier(tier string) *JobUpdate {
j.task.Tier(tier)
return j
}
func (j *JobUpdate) TaskMaxFailure(maxFail int32) *JobUpdate {
j.task.MaxFailure(maxFail)
return j
}
func (j *JobUpdate) IsService(isService bool) *JobUpdate {
j.task.IsService(isService)
return j
}
func (j *JobUpdate) Priority(priority int32) *JobUpdate {
j.task.Priority(priority)
return j
}
func (j *JobUpdate) Production(production bool) *JobUpdate {
j.task.Production(production)
return j
}
func (j *JobUpdate) TaskConfig() *aurora.TaskConfig {
return j.task.TaskConfig()
}
func (j *JobUpdate) AddURIs(extract bool, cache bool, values ...string) *JobUpdate {
j.task.AddURIs(extract, cache, values...)
return j
}
func (j *JobUpdate) AddLabel(key string, value string) *JobUpdate {
j.task.AddLabel(key, value)
return j
}
func (j *JobUpdate) AddNamedPorts(names ...string) *JobUpdate {
j.task.AddNamedPorts(names...)
return j
}
func (j *JobUpdate) AddPorts(num int) *JobUpdate {
j.task.AddPorts(num)
return j
}
func (j *JobUpdate) AddValueConstraint(name string, negated bool, values ...string) *JobUpdate {
j.task.AddValueConstraint(name, negated, values...)
return j
}
func (j *JobUpdate) AddLimitConstraint(name string, limit int32) *JobUpdate {
j.task.AddLimitConstraint(name, limit)
return j
}
func (j *JobUpdate) AddDedicatedConstraint(role, name string) *JobUpdate {
j.task.AddDedicatedConstraint(role, name)
return j
}
func (j *JobUpdate) Container(container Container) *JobUpdate {
j.task.Container(container)
return j
}
func (j *JobUpdate) JobKey() aurora.JobKey {
return j.task.JobKey()
}
func (j *JobUpdate) ThermosExecutor(thermos ThermosExecutor) *JobUpdate {
j.task.ThermosExecutor(thermos)
return j
}
func (j *JobUpdate) BuildThermosPayload() error {
return j.task.BuildThermosPayload()
}
func (j *JobUpdate) PartitionPolicy(reschedule bool, delay int64) *JobUpdate {
j.task.PartitionPolicy(aurora.PartitionPolicy{
Reschedule: reschedule,
DelaySecs: &delay,
})
return j
}

View file

@ -14,73 +14,65 @@
package realis
type logger interface {
type Logger interface {
Println(v ...interface{})
Printf(format string, v ...interface{})
Print(v ...interface{})
}
// NoopLogger is a logger that can be attached to the client which will not print anything.
type NoopLogger struct{}
// Printf is a NOOP function here.
func (NoopLogger) Printf(format string, a ...interface{}) {}
// Print is a NOOP function here.
func (NoopLogger) Print(a ...interface{}) {}
// Println is a NOOP function here.
func (NoopLogger) Println(a ...interface{}) {}
// LevelLogger is a logger that can be configured to output different levels of information: Debug and Trace.
// Trace should only be enabled when very in depth information about the sequence of events a function took is needed.
type LevelLogger struct {
logger
Logger
debug bool
trace bool
}
// EnableDebug enables debug level logging for the LevelLogger
func (l *LevelLogger) EnableDebug(enable bool) {
l.debug = enable
}
// EnableTrace enables trace level logging for the LevelLogger
func (l *LevelLogger) EnableTrace(enable bool) {
l.trace = enable
}
func (l LevelLogger) debugPrintf(format string, a ...interface{}) {
func (l LevelLogger) DebugPrintf(format string, a ...interface{}) {
if l.debug {
l.Printf("[DEBUG] "+format, a...)
}
}
func (l LevelLogger) debugPrint(a ...interface{}) {
func (l LevelLogger) DebugPrint(a ...interface{}) {
if l.debug {
l.Print(append([]interface{}{"[DEBUG] "}, a...)...)
}
}
func (l LevelLogger) debugPrintln(a ...interface{}) {
func (l LevelLogger) DebugPrintln(a ...interface{}) {
if l.debug {
l.Println(append([]interface{}{"[DEBUG] "}, a...)...)
}
}
func (l LevelLogger) tracePrintf(format string, a ...interface{}) {
func (l LevelLogger) TracePrintf(format string, a ...interface{}) {
if l.trace {
l.Printf("[TRACE] "+format, a...)
}
}
func (l LevelLogger) tracePrint(a ...interface{}) {
func (l LevelLogger) TracePrint(a ...interface{}) {
if l.trace {
l.Print(append([]interface{}{"[TRACE] "}, a...)...)
}
}
func (l LevelLogger) tracePrintln(a ...interface{}) {
func (l LevelLogger) TracePrintln(a ...interface{}) {
if l.trace {
l.Println(append([]interface{}{"[TRACE] "}, a...)...)
}

View file

@ -12,46 +12,48 @@
* limitations under the License.
*/
// Collection of monitors to create synchronicity
package realis
import (
"time"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/pkg/errors"
)
// Monitor is a wrapper for the Realis client which allows us to have functions
// with the same name for Monitoring purposes.
// TODO(rdelvalle): Deprecate monitors and instead add prefix Monitor to
// all functions in this file like it is done in V2.
type Monitor struct {
Client Realis
}
// JobUpdate polls the scheduler every certain amount of time to see if the update has entered a terminal state.
func (m *Monitor) JobUpdate(
updateKey aurora.JobUpdateKey,
interval int,
timeout int) (bool, error) {
updateQ := aurora.JobUpdateQuery{
Key: &updateKey,
Limit: 1,
UpdateStatuses: TerminalUpdateStates(),
// MonitorJobUpdate polls the scheduler every certain amount of time to see if the update has succeeded.
// If the update entered a terminal update state but it is not ROLLED_FORWARD, this function will return an error.
func (c *Client) MonitorJobUpdate(updateKey aurora.JobUpdateKey, interval, timeout time.Duration) (bool, error) {
if interval < 1*time.Second {
interval = interval * time.Second
}
updateSummaries, err := m.JobUpdateQuery(
updateQ,
time.Duration(interval)*time.Second,
time.Duration(timeout)*time.Second)
status := updateSummaries[0].State.Status
if timeout < 1*time.Second {
timeout = timeout * time.Second
}
updateSummaries, err := c.MonitorJobUpdateQuery(
aurora.JobUpdateQuery{
Key: &updateKey,
Limit: 1,
UpdateStatuses: []aurora.JobUpdateStatus{
aurora.JobUpdateStatus_ROLLED_FORWARD,
aurora.JobUpdateStatus_ROLLED_BACK,
aurora.JobUpdateStatus_ABORTED,
aurora.JobUpdateStatus_ERROR,
aurora.JobUpdateStatus_FAILED,
},
},
interval,
timeout)
if err != nil {
return false, err
}
m.Client.RealisConfig().logger.Printf("job update status: %v\n", status)
status := updateSummaries[0].State.Status
c.RealisConfig().logger.Printf("job update status: %v\n", status)
// Rolled forward is the only state in which an update has been successfully updated
// if we encounter an inactive state and it is not at rolled forward, update failed
@ -68,22 +70,41 @@ func (m *Monitor) JobUpdate(
}
}
// JobUpdateStatus polls the scheduler every certain amount of time to see if the update has entered a specified state.
func (m *Monitor) JobUpdateStatus(updateKey aurora.JobUpdateKey,
// MonitorJobUpdateStatus polls the scheduler for information about an update until the update enters one of the
// desired states or until the function times out.
func (c *Client) MonitorJobUpdateStatus(updateKey aurora.JobUpdateKey,
desiredStatuses []aurora.JobUpdateStatus,
interval, timeout time.Duration) (aurora.JobUpdateStatus, error) {
if len(desiredStatuses) == 0 {
return aurora.JobUpdateStatus(-1), errors.New("no desired statuses provided")
}
// Make deep local copy to avoid side effects from job key being manipulated externally.
updateKeyLocal := &aurora.JobUpdateKey{
Job: &aurora.JobKey{
Role: updateKey.Job.GetRole(),
Environment: updateKey.Job.GetEnvironment(),
Name: updateKey.Job.GetName(),
},
ID: updateKey.GetID(),
}
updateQ := aurora.JobUpdateQuery{
Key: &updateKey,
Key: updateKeyLocal,
Limit: 1,
UpdateStatuses: desiredStatuses,
}
summary, err := m.JobUpdateQuery(updateQ, interval, timeout)
return summary[0].State.Status, err
summary, err := c.MonitorJobUpdateQuery(updateQ, interval, timeout)
if len(summary) > 0 {
return summary[0].State.Status, err
}
return aurora.JobUpdateStatus(-1), err
}
// JobUpdateQuery polls the scheduler every certain amount of time to see if the query call returns any results.
func (m *Monitor) JobUpdateQuery(
func (c *Client) MonitorJobUpdateQuery(
updateQuery aurora.JobUpdateQuery,
interval time.Duration,
timeout time.Duration) ([]*aurora.JobUpdateSummary, error) {
@ -92,20 +113,16 @@ func (m *Monitor) JobUpdateQuery(
defer ticker.Stop()
timer := time.NewTimer(timeout)
defer timer.Stop()
var cliErr error
var respDetail *aurora.Response
for {
select {
case <-ticker.C:
respDetail, cliErr = m.Client.GetJobUpdateSummaries(&updateQuery)
updateSummaryResults, cliErr := c.GetJobUpdateSummaries(&updateQuery)
if cliErr != nil {
return nil, cliErr
}
updateSummaries := respDetail.Result_.GetJobUpdateSummariesResult_.UpdateSummaries
if len(updateSummaries) >= 1 {
return updateSummaries, nil
if len(updateSummaryResults.GetUpdateSummaries()) >= 1 {
return updateSummaryResults.GetUpdateSummaries(), nil
}
case <-timer.C:
@ -114,104 +131,37 @@ func (m *Monitor) JobUpdateQuery(
}
}
// AutoPaused monitor is a special monitor for auto pause enabled batch updates. This monitor ensures that the update
// being monitored is capable of auto pausing and has auto pausing enabled. After verifying this information,
// the monitor watches for the job to enter the ROLL_FORWARD_PAUSED state and calculates the current batch
// the update is in using information from the update configuration.
func (m *Monitor) AutoPausedUpdateMonitor(key aurora.JobUpdateKey, interval, timeout time.Duration) (int, error) {
key.Job = &aurora.JobKey{
Role: key.Job.Role,
Environment: key.Job.Environment,
Name: key.Job.Name,
}
query := aurora.JobUpdateQuery{
UpdateStatuses: aurora.ACTIVE_JOB_UPDATE_STATES,
Limit: 1,
Key: &key,
}
response, err := m.Client.JobUpdateDetails(query)
if err != nil {
return -1, errors.Wrap(err, "unable to get information about update")
}
// TODO (rdelvalle): check for possible nil values when going down the list of structs
updateDetails := response.Result_.GetJobUpdateDetailsResult_.DetailsList
if len(updateDetails) == 0 {
return -1, errors.Errorf("details for update could not be found")
}
updateStrategy := updateDetails[0].Update.Instructions.Settings.UpdateStrategy
var batchSizes []int32
switch {
case updateStrategy.IsSetVarBatchStrategy():
batchSizes = updateStrategy.VarBatchStrategy.GroupSizes
if !updateStrategy.VarBatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
case updateStrategy.IsSetBatchStrategy():
batchSizes = []int32{updateStrategy.BatchStrategy.GroupSize}
if !updateStrategy.BatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
default:
return -1, errors.Errorf("update is not using a batch update strategy")
}
query.UpdateStatuses = append(TerminalUpdateStates(), aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED)
summary, err := m.JobUpdateQuery(query, interval, timeout)
if err != nil {
return -1, err
}
if summary[0].State.Status != aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED {
return -1, errors.Errorf("update is in a terminal state %v", summary[0].State.Status)
}
updatingInstances := make(map[int32]struct{})
for _, e := range updateDetails[0].InstanceEvents {
// We only care about INSTANCE_UPDATING actions because we only care that they've been attempted
if e != nil && e.GetAction() == aurora.JobUpdateAction_INSTANCE_UPDATING {
updatingInstances[e.GetInstanceId()] = struct{}{}
}
}
return calculateCurrentBatch(int32(len(updatingInstances)), batchSizes), nil
// Monitor a AuroraJob until all instances enter one of the LiveStates
func (c *Client) MonitorInstances(key aurora.JobKey, instances int32, interval, timeout time.Duration) (bool, error) {
return c.MonitorScheduleStatus(key, instances, aurora.LIVE_STATES, interval, timeout)
}
// Monitor a Job until all instances enter one of the LIVE_STATES
func (m *Monitor) Instances(key *aurora.JobKey, instances int32, interval, timeout int) (bool, error) {
return m.ScheduleStatus(key, instances, LiveStates, interval, timeout)
}
// ScheduleStatus will monitor a Job until all instances enter a desired status.
// Monitor a AuroraJob until all instances enter a desired status.
// Defaults sets of desired statuses provided by the thrift API include:
// ACTIVE_STATES, SLAVE_ASSIGNED_STATES, LIVE_STATES, and TERMINAL_STATES
func (m *Monitor) ScheduleStatus(
key *aurora.JobKey,
// ActiveStates, SlaveAssignedStates, LiveStates, and TerminalStates
func (c *Client) MonitorScheduleStatus(key aurora.JobKey,
instanceCount int32,
desiredStatuses map[aurora.ScheduleStatus]bool,
interval int,
timeout int) (bool, error) {
ticker := time.NewTicker(time.Second * time.Duration(interval))
defer ticker.Stop()
timer := time.NewTimer(time.Second * time.Duration(timeout))
defer timer.Stop()
wantedStatuses := make([]aurora.ScheduleStatus, 0)
for status := range desiredStatuses {
wantedStatuses = append(wantedStatuses, status)
desiredStatuses []aurora.ScheduleStatus,
interval, timeout time.Duration) (bool, error) {
if interval < 1*time.Second {
interval = interval * time.Second
}
if timeout < 1*time.Second {
timeout = timeout * time.Second
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-ticker.C:
// Query Aurora for the state of the job key ever interval
instCount, cliErr := m.Client.GetInstanceIds(key, wantedStatuses)
instCount, cliErr := c.GetInstanceIds(key, desiredStatuses)
if cliErr != nil {
return false, errors.Wrap(cliErr, "Unable to communicate with Aurora")
}
@ -221,18 +171,23 @@ func (m *Monitor) ScheduleStatus(
case <-timer.C:
// If the timer runs out, return a timeout error to user
return false, newTimedoutError(errors.New("schedule status monitor timed out"))
return false, newTimedoutError(errors.New("schedule status monitor timedout"))
}
}
}
// HostMaintenance will monitor host status until all hosts match the status provided.
// Returns a map where the value is true if the host
// Monitor host status until all hosts match the status provided. Returns a map where the value is true if the host
// is in one of the desired mode(s) or false if it is not as of the time when the monitor exited.
func (m *Monitor) HostMaintenance(
hosts []string,
func (c *Client) MonitorHostMaintenance(hosts []string,
modes []aurora.MaintenanceMode,
interval, timeout int) (map[string]bool, error) {
interval, timeout time.Duration) (map[string]bool, error) {
if interval < 1*time.Second {
interval = interval * time.Second
}
if timeout < 1*time.Second {
timeout = timeout * time.Second
}
// Transform modes to monitor for into a set for easy lookup
desiredMode := make(map[aurora.MaintenanceMode]struct{})
@ -241,8 +196,7 @@ func (m *Monitor) HostMaintenance(
}
// Turn slice into a host set to eliminate duplicates.
// We also can't use a simple count because multiple modes means
// we can have multiple matches for a single host.
// We also can't use a simple count because multiple modes means we can have multiple matches for a single host.
// I.e. host A transitions from ACTIVE to DRAINING to DRAINED while monitored
remainingHosts := make(map[string]struct{})
for _, host := range hosts {
@ -251,16 +205,16 @@ func (m *Monitor) HostMaintenance(
hostResult := make(map[string]bool)
ticker := time.NewTicker(time.Second * time.Duration(interval))
ticker := time.NewTicker(interval)
defer ticker.Stop()
timer := time.NewTimer(time.Second * time.Duration(timeout))
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case <-ticker.C:
// Client call has multiple retries internally
_, result, err := m.Client.MaintenanceStatus(hosts...)
result, err := c.MaintenanceStatus(hosts...)
if err != nil {
// Error is either a payload error or a severe connection error
for host := range remainingHosts {
@ -286,7 +240,73 @@ func (m *Monitor) HostMaintenance(
hostResult[host] = false
}
return hostResult, newTimedoutError(errors.New("host maintenance monitor timed out"))
return hostResult, newTimedoutError(errors.New("host maintenance monitor timedout"))
}
}
}
// MonitorAutoPausedUpdate is a special monitor for auto pause enabled batch updates. This monitor ensures that the update
// being monitored is capable of auto pausing and has auto pausing enabled. After verifying this information,
// the monitor watches for the job to enter the ROLL_FORWARD_PAUSED state and calculates the current batch
// the update is in using information from the update configuration.
func (c *Client) MonitorAutoPausedUpdate(key aurora.JobUpdateKey, interval, timeout time.Duration) (int, error) {
key.Job = &aurora.JobKey{
Role: key.Job.Role,
Environment: key.Job.Environment,
Name: key.Job.Name,
}
query := aurora.JobUpdateQuery{
UpdateStatuses: aurora.ACTIVE_JOB_UPDATE_STATES,
Limit: 1,
Key: &key,
}
updateDetails, err := c.JobUpdateDetails(query)
if err != nil {
return -1, errors.Wrap(err, "unable to get information about update")
}
if len(updateDetails) == 0 {
return -1, errors.Errorf("details for update could not be found")
}
updateStrategy := updateDetails[0].Update.Instructions.Settings.UpdateStrategy
var batchSizes []int32
switch {
case updateStrategy.IsSetVarBatchStrategy():
batchSizes = updateStrategy.VarBatchStrategy.GroupSizes
if !updateStrategy.VarBatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
case updateStrategy.IsSetBatchStrategy():
batchSizes = []int32{updateStrategy.BatchStrategy.GroupSize}
if !updateStrategy.BatchStrategy.AutopauseAfterBatch {
return -1, errors.Errorf("update does not have auto pause enabled")
}
default:
return -1, errors.Errorf("update is not using a batch update strategy")
}
query.UpdateStatuses = append(TerminalUpdateStates(), aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED)
summary, err := c.MonitorJobUpdateQuery(query, interval, timeout)
if err != nil {
return -1, err
}
// Summary 0 is assumed to exist because MonitorJobUpdateQuery will return an error if there is no summaries
if !(summary[0].State.Status == aurora.JobUpdateStatus_ROLL_FORWARD_PAUSED ||
summary[0].State.Status == aurora.JobUpdateStatus_ROLLED_FORWARD) {
return -1, errors.Errorf("update is in a terminal state %v", summary[0].State.Status)
}
updatingInstances := make(map[int32]struct{})
for _, e := range updateDetails[0].InstanceEvents {
// We only care about INSTANCE_UPDATING actions because we only care that they've been attempted
if e != nil && e.GetAction() == aurora.JobUpdateAction_INSTANCE_UPDATING {
updatingInstances[e.GetInstanceId()] = struct{}{}
}
}
return calculateCurrentBatch(int32(len(updatingInstances)), batchSizes), nil
}

434
offer.go Normal file
View file

@ -0,0 +1,434 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Offers on [aurora-scheduler]/offers endpoint
type Offer struct {
ID struct {
Value string `json:"value"`
} `json:"id"`
FrameworkID struct {
Value string `json:"value"`
} `json:"framework_id"`
AgentID struct {
Value string `json:"value"`
} `json:"agent_id"`
Hostname string `json:"hostname"`
URL struct {
Scheme string `json:"scheme"`
Address struct {
Hostname string `json:"hostname"`
IP string `json:"ip"`
Port int `json:"port"`
} `json:"address"`
Path string `json:"path"`
Query []interface{} `json:"query"`
} `json:"url"`
Resources []struct {
Name string `json:"name"`
Type string `json:"type"`
Ranges struct {
Range []struct {
Begin int `json:"begin"`
End int `json:"end"`
} `json:"range"`
} `json:"ranges,omitempty"`
Role string `json:"role"`
Reservations []interface{} `json:"reservations"`
Scalar struct {
Value float64 `json:"value"`
} `json:"scalar,omitempty"`
} `json:"resources"`
Attributes []struct {
Name string `json:"name"`
Type string `json:"type"`
Text struct {
Value string `json:"value"`
} `json:"text"`
} `json:"attributes"`
ExecutorIds []struct {
Value string `json:"value"`
} `json:"executor_ids"`
}
// hosts on [aurora-scheduler]/maintenance endpoint
type MaintenanceList struct {
Drained []string `json:"DRAINED"`
Scheduled []string `json:"SCHEDULED"`
Draining map[string][]string `json:"DRAINING"`
}
type OfferCount map[float64]int64
type OfferGroupReport map[string]OfferCount
type OfferReport map[string]OfferGroupReport
// MaintenanceHosts list all the hosts under maintenance
func (c *Client) MaintenanceHosts() ([]string, error) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.config.insecureSkipVerify},
}
request := &http.Client{Transport: tr}
resp, err := request.Get(fmt.Sprintf("%s/maintenance", c.GetSchedulerURL()))
if err != nil {
return nil, err
}
defer resp.Body.Close()
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(resp.Body); err != nil {
return nil, err
}
var list MaintenanceList
if err := json.Unmarshal(buf.Bytes(), &list); err != nil {
return nil, err
}
hosts := append(list.Drained, list.Scheduled...)
for drainingHost := range list.Draining {
hosts = append(hosts, drainingHost)
}
return hosts, nil
}
// Offers pulls data from /offers endpoint
func (c *Client) Offers() ([]Offer, error) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.config.insecureSkipVerify},
}
request := &http.Client{Transport: tr}
resp, err := request.Get(fmt.Sprintf("%s/offers", c.GetSchedulerURL()))
if err != nil {
return []Offer{}, err
}
defer resp.Body.Close()
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(resp.Body); err != nil {
return nil, err
}
var offers []Offer
if err := json.Unmarshal(buf.Bytes(), &offers); err != nil {
return []Offer{}, err
}
return offers, nil
}
// AvailOfferReport returns a detailed summary of offers available for use.
// For example, 2 nodes offer 32 cpus and 10 nodes offer 1 cpus.
func (c *Client) AvailOfferReport() (OfferReport, error) {
maintHosts, err := c.MaintenanceHosts()
if err != nil {
return nil, err
}
maintHostSet := map[string]bool{}
for _, h := range maintHosts {
maintHostSet[h] = true
}
// Get a list of offers
offers, err := c.Offers()
if err != nil {
return nil, err
}
report := OfferReport{}
for _, o := range offers {
if maintHostSet[o.Hostname] {
continue
}
group := "non-dedicated"
for _, a := range o.Attributes {
if a.Name == "dedicated" {
group = a.Text.Value
break
}
}
if _, ok := report[group]; !ok {
report[group] = map[string]OfferCount{}
}
for _, r := range o.Resources {
if _, ok := report[group][r.Name]; !ok {
report[group][r.Name] = OfferCount{}
}
val := 0.0
switch r.Type {
case "SCALAR":
val = r.Scalar.Value
case "RANGES":
for _, pr := range r.Ranges.Range {
val += float64(pr.End - pr.Begin + 1)
}
default:
return nil, fmt.Errorf("%s is not supported", r.Type)
}
report[group][r.Name][val]++
}
}
return report, nil
}
// FitTasks computes the number tasks can be fit in a list of offer
func (c *Client) FitTasks(taskConfig *aurora.TaskConfig, offers []Offer) (int64, error) {
// count the number of tasks per limit contraint: limit.name -> limit.value -> count
limitCounts := map[string]map[string]int64{}
for _, c := range taskConfig.Constraints {
if c.Constraint.Limit != nil {
limitCounts[c.Name] = map[string]int64{}
}
}
request := ResourcesToMap(taskConfig.Resources)
// validate resource request
if len(request) == 0 {
return -1, fmt.Errorf("Resource request %v must not be empty", request)
}
isValid := false
for _, resVal := range request {
if resVal > 0 {
isValid = true
break
}
}
if !isValid {
return -1, fmt.Errorf("Resource request %v is not valid", request)
}
// pull the list of hosts under maintenance
maintHosts, err := c.MaintenanceHosts()
if err != nil {
return -1, err
}
maintHostSet := map[string]bool{}
for _, h := range maintHosts {
maintHostSet[h] = true
}
numTasks := int64(0)
for _, o := range offers {
// skip the hosts under maintenance
if maintHostSet[o.Hostname] {
continue
}
numTasksPerOffer := int64(-1)
for resName, resVal := range request {
// skip as we can fit a infinite number of tasks with 0 demand.
if resVal == 0 {
continue
}
avail := 0.0
for _, r := range o.Resources {
if r.Name != resName {
continue
}
switch r.Type {
case "SCALAR":
avail = r.Scalar.Value
case "RANGES":
for _, pr := range r.Ranges.Range {
avail += float64(pr.End - pr.Begin + 1)
}
default:
return -1, fmt.Errorf("%s is not supported", r.Type)
}
}
numTasksPerResource := int64(avail / resVal)
if numTasksPerResource < numTasksPerOffer || numTasksPerOffer < 0 {
numTasksPerOffer = numTasksPerResource
}
}
numTasks += fitConstraints(taskConfig, &o, limitCounts, numTasksPerOffer)
}
return numTasks, nil
}
func fitConstraints(taskConfig *aurora.TaskConfig,
offer *Offer,
limitCounts map[string]map[string]int64,
numTasksPerOffer int64) int64 {
// check dedicated attributes vs. constraints
if !isDedicated(offer, taskConfig.Job.Role, taskConfig.Constraints) {
return 0
}
limitConstraints := []*aurora.Constraint{}
for _, c := range taskConfig.Constraints {
// look for corresponding attribute
attFound := false
for _, a := range offer.Attributes {
if a.Name == c.Name {
attFound = true
}
}
// constraint not found in offer's attributes
if !attFound {
return 0
}
if c.Constraint.Value != nil && !valueConstraint(offer, c) {
// value constraint is not satisfied
return 0
} else if c.Constraint.Limit != nil {
limitConstraints = append(limitConstraints, c)
limit := limitConstraint(offer, c, limitCounts)
if numTasksPerOffer > limit && limit >= 0 {
numTasksPerOffer = limit
}
}
}
// update limitCounts
for _, c := range limitConstraints {
for _, a := range offer.Attributes {
if a.Name == c.Name {
limitCounts[a.Name][a.Text.Value] += numTasksPerOffer
}
}
}
return numTasksPerOffer
}
func isDedicated(offer *Offer, role string, constraints []*aurora.Constraint) bool {
// get all dedicated attributes of an offer
dedicatedAtts := map[string]bool{}
for _, a := range offer.Attributes {
if a.Name == "dedicated" {
dedicatedAtts[a.Text.Value] = true
}
}
if len(dedicatedAtts) == 0 {
return true
}
// check if constraints are matching dedicated attributes
matched := false
for _, c := range constraints {
if c.Name == "dedicated" && c.Constraint.Value != nil {
found := false
for _, v := range c.Constraint.Value.Values {
if dedicatedAtts[v] && strings.HasPrefix(v, fmt.Sprintf("%s/", role)) {
found = true
break
}
}
if found {
matched = true
} else {
return false
}
}
}
return matched
}
// valueConstraint checks Value Contraints of task if the are matched by the offer.
// more details can be found here https://aurora.apache.org/documentation/latest/features/constraints/
func valueConstraint(offer *Offer, constraint *aurora.Constraint) bool {
matched := false
for _, a := range offer.Attributes {
if a.Name == constraint.Name {
for _, v := range constraint.Constraint.Value.Values {
matched = (a.Text.Value == v && !constraint.Constraint.Value.Negated) ||
(a.Text.Value != v && constraint.Constraint.Value.Negated)
if matched {
break
}
}
if matched {
break
}
}
}
return matched
}
// limitConstraint limits the number of pods on a group which has the same attribute.
// more details can be found here https://aurora.apache.org/documentation/latest/features/constraints/
func limitConstraint(offer *Offer, constraint *aurora.Constraint, limitCounts map[string]map[string]int64) int64 {
limit := int64(-1)
for _, a := range offer.Attributes {
// limit constraint found
if a.Name == constraint.Name {
curr := limitCounts[a.Name][a.Text.Value]
currLimit := int64(constraint.Constraint.Limit.Limit)
if curr >= currLimit {
return 0
}
if currLimit-curr < limit || limit < 0 {
limit = currLimit - curr
}
}
}
return limit
}

1050
realis.go

File diff suppressed because it is too large Load diff

View file

@ -1,56 +1,30 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"context"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/pkg/errors"
)
// TODO(rdelvalle): Consider moving these functions to another interface. It would be a backwards incompatible change,
// but would add safety.
// Set a list of nodes to DRAINING. This means nothing will be able to be scheduled on them and any existing
// tasks will be killed and re-scheduled elsewhere in the cluster. Tasks from DRAINING nodes are not guaranteed
// to return to running unless there is enough capacity in the cluster to run them.
func (r *realisClient) DrainHosts(hosts ...string) (*aurora.Response, *aurora.DrainHostsResult_, error) {
var result *aurora.DrainHostsResult_
if len(hosts) == 0 {
return nil, nil, errors.New("no hosts provided to drain")
}
drainList := aurora.NewHosts()
drainList.HostNames = hosts
r.logger.debugPrintf("DrainHosts Thrift Payload: %v\n", drainList)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.DrainHosts(context.TODO(), drainList)
})
if retryErr != nil {
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
}
if resp.GetResult_() != nil {
result = resp.GetResult_().GetDrainHostsResult_()
}
return resp, result, nil
}
// Start SLA Aware Drain.
// defaultSlaPolicy is the fallback SlaPolicy to use if a task does not have an SlaPolicy.
// After timeoutSecs, tasks will be forcefully drained without checking SLA.
func (r *realisClient) SLADrainHosts(
policy *aurora.SlaPolicy,
timeout int64,
hosts ...string) (*aurora.DrainHostsResult_, error) {
var result *aurora.DrainHostsResult_
func (c *Client) DrainHosts(hosts ...string) ([]*aurora.HostStatus, error) {
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to drain")
@ -59,216 +33,261 @@ func (r *realisClient) SLADrainHosts(
drainList := aurora.NewHosts()
drainList.HostNames = hosts
r.logger.debugPrintf("SLADrainHosts Thrift Payload: %v\n", drainList)
c.logger.DebugPrintf("DrainHosts Thrift Payload: %v\n", drainList)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.SlaDrainHosts(context.TODO(), drainList, policy, timeout)
})
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.DrainHosts(context.TODO(), drainList)
},
nil,
)
if retryErr != nil {
return result, errors.Wrap(retryErr, "Unable to recover connection")
return nil, errors.Wrap(retryErr, "unable to recover connection")
}
if resp.GetResult_() != nil {
result = resp.GetResult_().GetDrainHostsResult_()
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetDrainHostsResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
return result, nil
return resp.GetResult_().GetDrainHostsResult_().GetStatuses(), nil
}
func (r *realisClient) StartMaintenance(hosts ...string) (*aurora.Response, *aurora.StartMaintenanceResult_, error) {
var result *aurora.StartMaintenanceResult_
// Start SLA Aware Drain.
// defaultSlaPolicy is the fallback SlaPolicy to use if a task does not have an SlaPolicy.
// After timeoutSecs, tasks will be forcefully drained without checking SLA.
func (c *Client) SLADrainHosts(policy *aurora.SlaPolicy, timeout int64, hosts ...string) ([]*aurora.HostStatus, error) {
if len(hosts) == 0 {
return nil, nil, errors.New("no hosts provided to start maintenance on")
return nil, errors.New("no hosts provided to drain")
}
if policy == nil || policy.CountSetFieldsSlaPolicy() == 0 {
policy = &defaultSlaPolicy
c.logger.Printf("Warning: start draining with default sla policy %v", policy)
}
if timeout < 0 {
c.logger.Printf("Warning: timeout %d secs is invalid, draining with default timeout %d secs",
timeout,
defaultSlaDrainTimeoutSecs)
timeout = defaultSlaDrainTimeoutSecs
}
drainList := aurora.NewHosts()
drainList.HostNames = hosts
c.logger.DebugPrintf("SLADrainHosts Thrift Payload: %v\n", drainList)
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.SlaDrainHosts(context.TODO(), drainList, policy, timeout)
},
nil,
)
if retryErr != nil {
return nil, errors.Wrap(retryErr, "unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetDrainHostsResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
return resp.GetResult_().GetDrainHostsResult_().GetStatuses(), nil
}
func (c *Client) StartMaintenance(hosts ...string) ([]*aurora.HostStatus, error) {
if len(hosts) == 0 {
return nil, errors.New("no hosts provided to start maintenance on")
}
hostList := aurora.NewHosts()
hostList.HostNames = hosts
r.logger.debugPrintf("StartMaintenance Thrift Payload: %v\n", hostList)
c.logger.DebugPrintf("StartMaintenance Thrift Payload: %v\n", hostList)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.StartMaintenance(context.TODO(), hostList)
})
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.StartMaintenance(context.TODO(), hostList)
},
nil,
)
if retryErr != nil {
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
return nil, errors.Wrap(retryErr, "unable to recover connection")
}
if resp.GetResult_() != nil {
result = resp.GetResult_().GetStartMaintenanceResult_()
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetStartMaintenanceResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
return resp, result, nil
return resp.GetResult_().GetStartMaintenanceResult_().GetStatuses(), nil
}
func (r *realisClient) EndMaintenance(hosts ...string) (*aurora.Response, *aurora.EndMaintenanceResult_, error) {
var result *aurora.EndMaintenanceResult_
func (c *Client) EndMaintenance(hosts ...string) ([]*aurora.HostStatus, error) {
if len(hosts) == 0 {
return nil, nil, errors.New("no hosts provided to end maintenance on")
return nil, errors.New("no hosts provided to end maintenance on")
}
hostList := aurora.NewHosts()
hostList.HostNames = hosts
r.logger.debugPrintf("EndMaintenance Thrift Payload: %v\n", hostList)
c.logger.DebugPrintf("EndMaintenance Thrift Payload: %v\n", hostList)
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.EndMaintenance(context.TODO(), hostList)
})
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.EndMaintenance(context.TODO(), hostList)
},
nil,
)
if retryErr != nil {
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
return nil, errors.Wrap(retryErr, "unable to recover connection")
}
if resp.GetResult_() != nil {
result = resp.GetResult_().GetEndMaintenanceResult_()
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetEndMaintenanceResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
return resp, result, nil
return resp.GetResult_().GetEndMaintenanceResult_().GetStatuses(), nil
}
func (r *realisClient) MaintenanceStatus(hosts ...string) (*aurora.Response, *aurora.MaintenanceStatusResult_, error) {
var result *aurora.MaintenanceStatusResult_
func (c *Client) MaintenanceStatus(hosts ...string) (*aurora.MaintenanceStatusResult_, error) {
if len(hosts) == 0 {
return nil, nil, errors.New("no hosts provided to get maintenance status from")
return nil, errors.New("no hosts provided to get maintenance status from")
}
hostList := aurora.NewHosts()
hostList.HostNames = hosts
r.logger.debugPrintf("MaintenanceStatus Thrift Payload: %v\n", hostList)
c.logger.DebugPrintf("MaintenanceStatus Thrift Payload: %v\n", hostList)
// Make thrift call. If we encounter an error sending the call, attempt to reconnect
// and continue trying to resend command until we run out of retries.
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.MaintenanceStatus(context.TODO(), hostList)
})
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.MaintenanceStatus(context.TODO(), hostList)
},
nil,
)
if retryErr != nil {
return resp, result, errors.Wrap(retryErr, "Unable to recover connection")
return nil, errors.Wrap(retryErr, "unable to recover connection")
}
if resp == nil || resp.GetResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
if resp.GetResult_() != nil {
result = resp.GetResult_().GetMaintenanceStatusResult_()
}
return resp, result, nil
return resp.GetResult_().GetMaintenanceStatusResult_(), nil
}
// SetQuota sets a quota aggregate for the given role
// TODO(zircote) Currently investigating an error that is returned
// from thrift calls that include resources for `NamedPort` and `NumGpu`
func (r *realisClient) SetQuota(role string, cpu *float64, ramMb *int64, diskMb *int64) (*aurora.Response, error) {
quota := &aurora.ResourceAggregate{
Resources: []*aurora.Resource{{NumCpus: cpu}, {RamMb: ramMb}, {DiskMb: diskMb}},
}
// TODO(zircote) Currently investigating an error that is returned from thrift calls that include resources for `NamedPort` and `NumGpu`
func (c *Client) SetQuota(role string, cpu *float64, ramMb *int64, diskMb *int64) error {
ramResource := aurora.NewResource()
ramResource.RamMb = ramMb
cpuResource := aurora.NewResource()
cpuResource.NumCpus = cpu
diskResource := aurora.NewResource()
diskResource.DiskMb = diskMb
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.SetQuota(context.TODO(), role, quota)
})
quota := aurora.NewResourceAggregate()
quota.Resources = []*aurora.Resource{ramResource, cpuResource, diskResource}
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.SetQuota(context.TODO(), role, quota)
},
nil,
)
if retryErr != nil {
return resp, errors.Wrap(retryErr, "Unable to set role quota")
return errors.Wrap(retryErr, "unable to set role quota")
}
return resp, retryErr
return retryErr
}
// GetQuota returns the resource aggregate for the given role
func (r *realisClient) GetQuota(role string) (*aurora.Response, error) {
func (c *Client) GetQuota(role string) (*aurora.GetQuotaResult_, error) {
resp, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.GetQuota(context.TODO(), role)
})
resp, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.GetQuota(context.TODO(), role)
},
nil,
)
if retryErr != nil {
return resp, errors.Wrap(retryErr, "Unable to get role quota")
return nil, errors.Wrap(retryErr, "unable to get role quota")
}
return resp, retryErr
if resp == nil || resp.GetResult_() == nil {
return nil, errors.New("unexpected response from scheduler")
}
return resp.GetResult_().GetGetQuotaResult_(), nil
}
// Force Aurora Scheduler to perform a snapshot and write to Mesos log
func (r *realisClient) Snapshot() error {
func (c *Client) Snapshot() error {
_, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.Snapshot(context.TODO())
})
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.Snapshot(context.TODO())
},
nil,
)
if retryErr != nil {
return errors.Wrap(retryErr, "Unable to recover connection")
return errors.Wrap(retryErr, "unable to recover connection")
}
return nil
}
// Force Aurora Scheduler to write backup file to a file in the backup directory
func (r *realisClient) PerformBackup() error {
func (c *Client) PerformBackup() error {
_, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.PerformBackup(context.TODO())
})
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.PerformBackup(context.TODO())
},
nil,
)
if retryErr != nil {
return errors.Wrap(retryErr, "Unable to recover connection")
return errors.Wrap(retryErr, "unable to recover connection")
}
return nil
}
func (r *realisClient) ForceImplicitTaskReconciliation() error {
// Force an Implicit reconciliation between Mesos and Aurora
func (c *Client) ForceImplicitTaskReconciliation() error {
_, retryErr := r.thriftCallWithRetries(
false,
func() (*aurora.Response, error) {
return r.adminClient.TriggerImplicitTaskReconciliation(context.TODO())
})
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.TriggerImplicitTaskReconciliation(context.TODO())
},
nil,
)
if retryErr != nil {
return errors.Wrap(retryErr, "Unable to recover connection")
return errors.Wrap(retryErr, "unable to recover connection")
}
return nil
}
func (r *realisClient) ForceExplicitTaskReconciliation(batchSize *int32) error {
// Force an Explicit reconciliation between Mesos and Aurora
func (c *Client) ForceExplicitTaskReconciliation(batchSize *int32) error {
if batchSize != nil && *batchSize < 1 {
return errors.New("invalid batch size")
return errors.New("invalid batch size.")
}
settings := aurora.NewExplicitReconciliationSettings()
settings.BatchSize = batchSize
_, retryErr := r.thriftCallWithRetries(false,
func() (*aurora.Response, error) {
return r.adminClient.TriggerExplicitTaskReconciliation(context.TODO(), settings)
})
_, retryErr := c.thriftCallWithRetries(false, func() (*aurora.Response, error) {
return c.adminClient.TriggerExplicitTaskReconciliation(context.TODO(), settings)
},
nil,
)
if retryErr != nil {
return errors.Wrap(retryErr, "Unable to recover connection")
return errors.Wrap(retryErr, "unable to recover connection")
}
return nil

181
realis_config.go Normal file
View file

@ -0,0 +1,181 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"strings"
"time"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
type clientConfig struct {
username, password string
url string
timeout time.Duration
transportProtocol TransportProtocol
cluster *Cluster
backoff Backoff
transport thrift.TTransport
protoFactory thrift.TProtocolFactory
logger *LevelLogger
insecureSkipVerify bool
certsPath string
clientKey, clientCert string
options []ClientOption
debug bool
trace bool
zkOptions []ZKOpt
failOnPermanentErrors bool
}
var defaultBackoff = Backoff{
Steps: 3,
Duration: 10 * time.Second,
Factor: 5.0,
Jitter: 0.1,
}
var defaultSlaPolicy = aurora.SlaPolicy{
PercentageSlaPolicy: &aurora.PercentageSlaPolicy{
Percentage: 66,
DurationSecs: 300,
},
}
const defaultSlaDrainTimeoutSecs = 900
type TransportProtocol int
const (
unsetProtocol TransportProtocol = iota
jsonProtocol
binaryProtocol
)
type ClientOption func(*clientConfig)
// clientConfig sets for options in clientConfig.
func BasicAuth(username, password string) ClientOption {
return func(config *clientConfig) {
config.username = username
config.password = password
}
}
func SchedulerUrl(url string) ClientOption {
return func(config *clientConfig) {
config.url = url
}
}
func Timeout(timeout time.Duration) ClientOption {
return func(config *clientConfig) {
config.timeout = timeout
}
}
func ZKCluster(cluster *Cluster) ClientOption {
return func(config *clientConfig) {
config.cluster = cluster
}
}
func ZKUrl(url string) ClientOption {
opts := []ZKOpt{ZKEndpoints(strings.Split(url, ",")...), ZKPath("/aurora/scheduler")}
return func(config *clientConfig) {
if config.zkOptions == nil {
config.zkOptions = opts
} else {
config.zkOptions = append(config.zkOptions, opts...)
}
}
}
func ThriftJSON() ClientOption {
return func(config *clientConfig) {
config.transportProtocol = jsonProtocol
}
}
func ThriftBinary() ClientOption {
return func(config *clientConfig) {
config.transportProtocol = binaryProtocol
}
}
func BackOff(b Backoff) ClientOption {
return func(config *clientConfig) {
config.backoff = b
}
}
func InsecureSkipVerify(InsecureSkipVerify bool) ClientOption {
return func(config *clientConfig) {
config.insecureSkipVerify = InsecureSkipVerify
}
}
func CertsPath(certspath string) ClientOption {
return func(config *clientConfig) {
config.certsPath = certspath
}
}
func ClientCerts(clientKey, clientCert string) ClientOption {
return func(config *clientConfig) {
config.clientKey, config.clientCert = clientKey, clientCert
}
}
// Use this option if you'd like to override default settings for connecting to Zookeeper.
// See zk.go for what is possible to set as an option.
func ZookeeperOptions(opts ...ZKOpt) ClientOption {
return func(config *clientConfig) {
config.zkOptions = opts
}
}
// Using the word set to avoid name collision with Interface.
func SetLogger(l Logger) ClientOption {
return func(config *clientConfig) {
config.logger = &LevelLogger{Logger: l}
}
}
// Enable debug statements.
func Debug() ClientOption {
return func(config *clientConfig) {
config.debug = true
}
}
// Enable trace statements.
func Trace() ClientOption {
return func(config *clientConfig) {
config.trace = true
}
}
// FailOnPermanentErrors - If the client encounters a connection error the standard library
// considers permanent, stop retrying and return an error to the user.
func FailOnPermanentErrors() ClientOption {
return func(config *clientConfig) {
config.failOnPermanentErrors = true
}
}

File diff suppressed because it is too large Load diff

72
realis_test.go Normal file
View file

@ -0,0 +1,72 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetCACerts(t *testing.T) {
certs, err := GetCerts("./examples/certs")
require.NoError(t, err)
assert.Equal(t, len(certs.Subjects()), 2)
}
func TestAuroraURLValidator(t *testing.T) {
t.Run("badURL", func(t *testing.T) {
url, err := validateAuroraAddress("http://badurl.com/badpath")
assert.Empty(t, url)
assert.Error(t, err)
})
t.Run("URLHttp", func(t *testing.T) {
url, err := validateAuroraAddress("http://goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLHttps", func(t *testing.T) {
url, err := validateAuroraAddress("https://goodurl.com:8081/api")
assert.Equal(t, "https://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoPath", func(t *testing.T) {
url, err := validateAuroraAddress("http://goodurl.com:8081")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("ipAddrNoPath", func(t *testing.T) {
url, err := validateAuroraAddress("http://192.168.1.33:8081")
assert.Equal(t, "http://192.168.1.33:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocol", func(t *testing.T) {
url, err := validateAuroraAddress("goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocolNoPathNoPort", func(t *testing.T) {
url, err := validateAuroraAddress("goodurl.com")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
}

View file

@ -17,9 +17,8 @@ package response
import (
"bytes"
"errors"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
// Get key from a response created by a StartJobUpdate call
@ -36,21 +35,13 @@ func ScheduleStatusResult(resp *aurora.Response) *aurora.ScheduleStatusResult_ {
}
func JobUpdateSummaries(resp *aurora.Response) []*aurora.JobUpdateSummary {
if resp == nil || resp.GetResult_() == nil || resp.GetResult_().GetGetJobUpdateSummariesResult_() == nil {
return nil
}
return resp.GetResult_().GetGetJobUpdateSummariesResult_().GetUpdateSummaries()
}
// Deprecated: Replaced by checks done inside of thriftCallHelper
func ResponseCodeCheck(resp *aurora.Response) (*aurora.Response, error) {
if resp == nil {
return resp, errors.New("Response is nil")
}
if resp.GetResponseCode() != aurora.ResponseCode_OK {
return resp, errors.New(CombineMessage(resp))
}
return resp, nil
}
// Based on aurora client: src/main/python/apache/aurora/client/base.py
func CombineMessage(resp *aurora.Response) string {
var buffer bytes.Buffer

198
retry.go
View file

@ -21,8 +21,8 @@ import (
"time"
"github.com/apache/thrift/lib/go/thrift"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/paypal/gorealis/response"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/response"
"github.com/pkg/errors"
)
@ -61,10 +61,11 @@ type ConditionFunc func() (done bool, err error)
//
// If the condition never returns true, ErrWaitTimeout is returned. Errors
// do not cause the function to return.
func ExponentialBackoff(backoff Backoff, logger logger, condition ConditionFunc) error {
func ExponentialBackoff(backoff Backoff, logger Logger, condition ConditionFunc) error {
var err error
var ok bool
var curStep int
duration := backoff.Duration
for curStep = 0; curStep < backoff.Steps; curStep++ {
@ -76,8 +77,7 @@ func ExponentialBackoff(backoff Backoff, logger logger, condition ConditionFunc)
adjusted = Jitter(duration, backoff.Jitter)
}
logger.Printf(
"A retryable error occurred during function call, backing off for %v before retrying\n", adjusted)
logger.Printf("A retryable error occurred during function call, backing off for %v before retrying\n", adjusted)
time.Sleep(adjusted)
duration = time.Duration(float64(duration) * backoff.Factor)
}
@ -114,17 +114,23 @@ func ExponentialBackoff(backoff Backoff, logger logger, condition ConditionFunc)
type auroraThriftCall func() (resp *aurora.Response, err error)
// Duplicates the functionality of ExponentialBackoff but is specifically targeted towards ThriftCalls.
func (r *realisClient) thriftCallWithRetries(
returnOnTimeout bool,
thriftCall auroraThriftCall) (*aurora.Response, error) {
// verifyOntimeout defines the type of function that will be used to verify whether a Thirft call to the Scheduler
// made it to the scheduler or not. In general, these types of functions will have to interact with the scheduler
// through the very same Thrift API which previously encountered a time-out from the client.
// This means that the functions themselves should be kept to a minimum number of Thrift calls.
// It should also be noted that this is a best effort mechanism and
// is likely to fail for the same reasons that the original call failed.
type verifyOnTimeout func() (*aurora.Response, bool)
// Duplicates the functionality of ExponentialBackoff but is specifically targeted towards ThriftCalls.
func (c *Client) thriftCallWithRetries(returnOnTimeout bool, thriftCall auroraThriftCall,
verifyOnTimeout verifyOnTimeout) (*aurora.Response, error) {
var resp *aurora.Response
var clientErr error
var curStep int
timeouts := 0
backoff := r.config.backoff
backoff := c.config.backoff
duration := backoff.Duration
for curStep = 0; curStep < backoff.Steps; curStep++ {
@ -136,8 +142,8 @@ func (r *realisClient) thriftCallWithRetries(
adjusted = Jitter(duration, backoff.Jitter)
}
r.logger.Printf(
"A retryable error occurred during thrift call, backing off for %v before retry %v\n",
c.logger.Printf(
"A retryable error occurred during thrift call, backing off for %v before retry %v",
adjusted,
curStep)
@ -149,101 +155,104 @@ func (r *realisClient) thriftCallWithRetries(
// Placing this in an anonymous function in order to create a new, short-lived stack allowing unlock
// to be run in case of a panic inside of thriftCall.
func() {
r.lock.Lock()
defer r.lock.Unlock()
c.lock.Lock()
defer c.lock.Unlock()
resp, clientErr = thriftCall()
r.logger.tracePrintf("Aurora Thrift Call ended resp: %v clientErr: %v\n", resp, clientErr)
c.logger.TracePrintf("Aurora Thrift Call ended resp: %v clientErr: %v", resp, clientErr)
}()
// Check if our thrift call is returning an error. This is a retryable event as we don't know
// if it was caused by network issues.
if clientErr != nil {
// Print out the error to the user
r.logger.Printf("Client Error: %v\n", clientErr)
c.logger.Printf("Client Error: %v", clientErr)
// Determine if error is a temporary URL error by going up the stack
e, ok := clientErr.(thrift.TTransportException)
if ok {
r.logger.debugPrint("Encountered a transport exception")
temporary, timedout := isConnectionError(clientErr)
if !temporary && c.RealisConfig().failOnPermanentErrors {
return nil, errors.Wrap(clientErr, "permanent connection error")
}
e, ok := e.Err().(*url.Error)
if ok {
// EOF error occurs when the server closes the read buffer of the client. This is common
// when the server is overloaded and should be retried. All other errors that are permanent
// will not be retried.
if e.Err != io.EOF && !e.Temporary() && r.RealisConfig().failOnPermanentErrors {
return nil, errors.Wrap(clientErr, "permanent connection error")
}
// Corner case where thrift payload was received by Aurora but connection timedout before Aurora was
// able to reply. In this case we will return whatever response was received and a TimedOut behaving
// error. Users can take special action on a timeout by using IsTimedout and reacting accordingly.
if e.Timeout() {
timeouts++
r.logger.debugPrintf(
"Client closed connection (timedout) %d times before server responded, "+
"consider increasing connection timeout",
timeouts)
if returnOnTimeout {
return resp, newTimedoutError(errors.New("client connection closed before server answer"))
}
}
}
// There exists a corner case where thrift payload was received by Aurora but
// connection timed out before Aurora was able to reply.
// Users can take special action on a timeout by using IsTimedout and reacting accordingly
// if they have configured the client to return on a timeout.
if timedout && returnOnTimeout {
return resp, newTimedoutError(errors.New("client connection closed before server answer"))
}
// In the future, reestablish connection should be able to check if it is actually possible
// to make a thrift call to Aurora. For now, a reconnect should always lead to a retry.
// Ignoring error due to the fact that an error should be retried regardless
reestablishErr := r.ReestablishConn()
reestablishErr := c.ReestablishConn()
if reestablishErr != nil {
r.logger.debugPrintf("error re-establishing connection ", reestablishErr)
}
} else {
// If there was no client error, but the response is nil, something went wrong.
// Ideally, we'll never encounter this but we're placing a safeguard here.
if resp == nil {
return nil, errors.New("response from aurora is nil")
c.logger.DebugPrintf("error re-establishing connection ", reestablishErr)
}
// Check Response Code from thrift and make a decision to continue retrying or not.
switch responseCode := resp.GetResponseCode(); responseCode {
// If users did not opt for a return on timeout in order to react to a timedout error,
// attempt to verify that the call made it to the scheduler after the connection was re-established.
if timedout {
timeouts++
c.logger.DebugPrintf(
"Client closed connection %d times before server responded, "+
"consider increasing connection timeout",
timeouts)
// If the thrift call succeeded, stop retrying
case aurora.ResponseCode_OK:
return resp, nil
// If the response code is transient, continue retrying
case aurora.ResponseCode_ERROR_TRANSIENT:
r.logger.Println("Aurora replied with Transient error code, retrying")
continue
// Failure scenarios, these indicate a bad payload or a bad config. Stop retrying.
case aurora.ResponseCode_INVALID_REQUEST,
aurora.ResponseCode_ERROR,
aurora.ResponseCode_AUTH_FAILED,
aurora.ResponseCode_JOB_UPDATING_ERROR:
r.logger.Printf("Terminal Response Code %v from Aurora, won't retry\n", resp.GetResponseCode().String())
return resp, errors.New(response.CombineMessage(resp))
// The only case that should fall down to here is a WARNING response code.
// It is currently not used as a response in the scheduler so it is unknown how to handle it.
default:
r.logger.debugPrintf("unhandled response code %v received from Aurora\n", responseCode)
return nil, errors.Errorf("unhandled response code from Aurora %v", responseCode.String())
// Allow caller to provide a function which checks if the original call was successful before
// it timed out.
if verifyOnTimeout != nil {
if verifyResp, ok := verifyOnTimeout(); ok {
c.logger.Print("verified that the call went through successfully after a client timeout")
// Response here might be different than the original as it is no longer constructed
// by the scheduler but mimicked.
// This is OK since the scheduler is very unlikely to change responses at this point in its
// development cycle but we must be careful to not return an incorrectly constructed response.
return verifyResp, nil
}
}
}
// Retry the thrift payload
continue
}
// If there was no client error, but the response is nil, something went wrong.
// Ideally, we'll never encounter this but we're placing a safeguard here.
if resp == nil {
return nil, errors.New("response from aurora is nil")
}
// Check Response Code from thrift and make a decision to continue retrying or not.
switch responseCode := resp.GetResponseCode(); responseCode {
// If the thrift call succeeded, stop retrying
case aurora.ResponseCode_OK:
return resp, nil
// If the response code is transient, continue retrying
case aurora.ResponseCode_ERROR_TRANSIENT:
c.logger.Println("Aurora replied with Transient error code, retrying")
continue
// Failure scenarios, these indicate a bad payload or a bad clientConfig. Stop retrying.
case aurora.ResponseCode_INVALID_REQUEST,
aurora.ResponseCode_ERROR,
aurora.ResponseCode_AUTH_FAILED,
aurora.ResponseCode_JOB_UPDATING_ERROR:
c.logger.Printf("Terminal Response Code %v from Aurora, won't retry\n", resp.GetResponseCode().String())
return resp, errors.New(response.CombineMessage(resp))
// The only case that should fall down to here is a WARNING response code.
// It is currently not used as a response in the scheduler so it is unknown how to handle it.
default:
c.logger.DebugPrintf("unhandled response code %v received from Aurora\n", responseCode)
return nil, errors.Errorf("unhandled response code from Aurora %v", responseCode.String())
}
}
r.logger.debugPrintf("it took %v retries to complete this operation\n", curStep)
if curStep > 1 {
r.config.logger.Printf("retried this thrift call %d time(s)", curStep)
c.config.logger.Printf("this thrift call was retried %d time(s)", curStep)
}
// Provide more information to the user wherever possible.
@ -253,3 +262,30 @@ func (r *realisClient) thriftCallWithRetries(
return nil, newRetryError(errors.New("ran out of retries"), curStep)
}
// isConnectionError processes the error received by the client.
// The return values indicate whether this was determined to be a temporary error
// and whether it was determined to be a timeout error
func isConnectionError(err error) (bool, bool) {
// Determine if error is a temporary URL error by going up the stack
transportException, ok := err.(thrift.TTransportException)
if !ok {
return false, false
}
urlError, ok := transportException.Err().(*url.Error)
if !ok {
return false, false
}
// EOF error occurs when the server closes the read buffer of the client. This is common
// when the server is overloaded and we consider it temporary.
// All other which are not temporary as per the member function Temporary(),
// are considered not temporary (permanent).
if urlError.Err != io.EOF && !urlError.Temporary() {
return false, false
}
return true, urlError.Timeout()
}

View file

@ -1,13 +0,0 @@
#!/bin/bash
docker-compose up -d
# If running docker-compose up gives any error, don't do anything.
if [ $? -ne 0 ]; then
exit
fi
# Since we run our docker compose setup in bridge mode to be able to run on MacOS, we have to launch a Docker container within the bridge network in order to avoid any routing issues.
docker run --rm -t -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10-stretch go test -v github.com/paypal/gorealis $@
docker-compose down

View file

@ -1,4 +1,4 @@
#!/bin/bash
# Since we run our docker compose setup in bridge mode to be able to run on MacOS, we have to launch a Docker container within the bridge network in order to avoid any routing issues.
docker run --rm -t -v $(pwd):/go/src/github.com/paypal/gorealis --network gorealis_aurora_cluster golang:1.10-stretch go test -v github.com/paypal/gorealis $@
docker run --rm -t -w /gorealis -v $GOPATH/pkg:/go/pkg -v $(pwd):/gorealis --network gorealis_aurora_cluster golang:1.17-buster go test -v github.com/aurora-scheduler/gorealis/v2 $@

465
task.go Normal file
View file

@ -0,0 +1,465 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"encoding/json"
"strconv"
"github.com/apache/thrift/lib/go/thrift"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
)
type ResourceType int
const (
CPU ResourceType = iota
RAM
DISK
GPU
)
const (
dedicated = "dedicated"
portPrefix = "org.apache.aurora.port."
)
type AuroraTask struct {
task *aurora.TaskConfig
resources map[ResourceType]*aurora.Resource
portCount int
thermos *ThermosExecutor
}
func NewTask() *AuroraTask {
numCpus := &aurora.Resource{}
ramMb := &aurora.Resource{}
diskMb := &aurora.Resource{}
numCpus.NumCpus = new(float64)
ramMb.RamMb = new(int64)
diskMb.DiskMb = new(int64)
resources := map[ResourceType]*aurora.Resource{CPU: numCpus, RAM: ramMb, DISK: diskMb}
return &AuroraTask{task: &aurora.TaskConfig{
Job: &aurora.JobKey{},
MesosFetcherUris: make([]*aurora.MesosFetcherURI, 0),
Metadata: make([]*aurora.Metadata, 0),
Constraints: make([]*aurora.Constraint, 0),
// Container is a Union so one container field must be set. Set Mesos by default.
Container: NewMesosContainer().Build(),
Resources: []*aurora.Resource{numCpus, ramMb, diskMb},
},
resources: resources,
portCount: 0}
}
// Helper method to convert aurora.TaskConfig to gorealis AuroraTask type
func TaskFromThrift(config *aurora.TaskConfig) *AuroraTask {
newTask := NewTask()
// Pass values using receivers as much as possible
newTask.
Environment(config.Job.Environment).
Role(config.Job.Role).
Name(config.Job.Name).
MaxFailure(config.MaxTaskFailures).
IsService(config.IsService).
Priority(config.Priority)
if config.Tier != nil {
newTask.Tier(*config.Tier)
}
if config.Production != nil {
newTask.Production(*config.Production)
}
if config.ExecutorConfig != nil {
newTask.
ExecutorName(config.ExecutorConfig.Name).
ExecutorData(config.ExecutorConfig.Data)
}
if config.PartitionPolicy != nil {
newTask.PartitionPolicy(
aurora.PartitionPolicy{
Reschedule: config.PartitionPolicy.Reschedule,
DelaySecs: thrift.Int64Ptr(*config.PartitionPolicy.DelaySecs),
})
}
// Make a deep copy of the task's container
if config.Container != nil {
if config.Container.Mesos != nil {
mesosContainer := NewMesosContainer()
if config.Container.Mesos.Image != nil {
if config.Container.Mesos.Image.Appc != nil {
mesosContainer.AppcImage(config.Container.Mesos.Image.Appc.Name, config.Container.Mesos.Image.Appc.ImageId)
} else if config.Container.Mesos.Image.Docker != nil {
mesosContainer.DockerImage(config.Container.Mesos.Image.Docker.Name, config.Container.Mesos.Image.Docker.Tag)
}
}
for _, vol := range config.Container.Mesos.Volumes {
mesosContainer.AddVolume(vol.ContainerPath, vol.HostPath, vol.Mode)
}
newTask.Container(mesosContainer)
} else if config.Container.Docker != nil {
dockerContainer := NewDockerContainer()
dockerContainer.Image(config.Container.Docker.Image)
for _, param := range config.Container.Docker.Parameters {
dockerContainer.AddParameter(param.Name, param.Value)
}
newTask.Container(dockerContainer)
}
}
// Copy all ports
for _, resource := range config.Resources {
// Copy only ports. Set CPU, RAM, DISK, and GPU
if resource != nil {
if resource.NamedPort != nil {
newTask.task.Resources = append(
newTask.task.Resources,
&aurora.Resource{NamedPort: thrift.StringPtr(*resource.NamedPort)},
)
newTask.portCount++
}
if resource.RamMb != nil {
newTask.RAM(*resource.RamMb)
}
if resource.NumCpus != nil {
newTask.CPU(*resource.NumCpus)
}
if resource.DiskMb != nil {
newTask.Disk(*resource.DiskMb)
}
if resource.NumGpus != nil {
newTask.GPU(*resource.NumGpus)
}
}
}
// Copy constraints
for _, constraint := range config.Constraints {
if constraint != nil && constraint.Constraint != nil {
newConstraint := aurora.Constraint{Name: constraint.Name}
taskConstraint := constraint.Constraint
if taskConstraint.Limit != nil {
newConstraint.Constraint = &aurora.TaskConstraint{
Limit: &aurora.LimitConstraint{Limit: taskConstraint.Limit.Limit},
}
newTask.task.Constraints = append(newTask.task.Constraints, &newConstraint)
} else if taskConstraint.Value != nil {
values := make([]string, 0)
for _, val := range taskConstraint.Value.Values {
values = append(values, val)
}
newConstraint.Constraint = &aurora.TaskConstraint{
Value: &aurora.ValueConstraint{Negated: taskConstraint.Value.Negated, Values: values}}
newTask.task.Constraints = append(newTask.task.Constraints, &newConstraint)
}
}
}
// Copy labels
for _, label := range config.Metadata {
newTask.task.Metadata = append(newTask.task.Metadata, &aurora.Metadata{Key: label.Key, Value: label.Value})
}
// Copy Mesos fetcher URIs
for _, uri := range config.MesosFetcherUris {
newTask.task.MesosFetcherUris = append(
newTask.task.MesosFetcherUris,
&aurora.MesosFetcherURI{
Value: uri.Value,
Extract: thrift.BoolPtr(*uri.Extract),
Cache: thrift.BoolPtr(*uri.Cache),
})
}
return newTask
}
// Set AuroraTask Key environment.
func (t *AuroraTask) Environment(env string) *AuroraTask {
t.task.Job.Environment = env
return t
}
// Set AuroraTask Key Role.
func (t *AuroraTask) Role(role string) *AuroraTask {
t.task.Job.Role = role
return t
}
// Set AuroraTask Key Name.
func (t *AuroraTask) Name(name string) *AuroraTask {
t.task.Job.Name = name
return t
}
// Set name of the executor that will the task will be configured to.
func (t *AuroraTask) ExecutorName(name string) *AuroraTask {
if t.task.ExecutorConfig == nil {
t.task.ExecutorConfig = aurora.NewExecutorConfig()
}
t.task.ExecutorConfig.Name = name
return t
}
// Will be included as part of entire task inside the scheduler that will be serialized.
func (t *AuroraTask) ExecutorData(data string) *AuroraTask {
if t.task.ExecutorConfig == nil {
t.task.ExecutorConfig = aurora.NewExecutorConfig()
}
t.task.ExecutorConfig.Data = data
return t
}
func (t *AuroraTask) CPU(cpus float64) *AuroraTask {
*t.resources[CPU].NumCpus = cpus
return t
}
func (t *AuroraTask) RAM(ram int64) *AuroraTask {
*t.resources[RAM].RamMb = ram
return t
}
func (t *AuroraTask) Disk(disk int64) *AuroraTask {
*t.resources[DISK].DiskMb = disk
return t
}
func (t *AuroraTask) GPU(gpu int64) *AuroraTask {
// GPU resource must be set explicitly since the scheduler by default
// rejects jobs with GPU resources attached to it.
if _, ok := t.resources[GPU]; !ok {
t.resources[GPU] = &aurora.Resource{}
t.task.Resources = append(t.task.Resources, t.resources[GPU])
}
t.resources[GPU].NumGpus = &gpu
return t
}
func (t *AuroraTask) Tier(tier string) *AuroraTask {
t.task.Tier = &tier
return t
}
// How many failures to tolerate before giving up.
func (t *AuroraTask) MaxFailure(maxFail int32) *AuroraTask {
t.task.MaxTaskFailures = maxFail
return t
}
// Restart the job's tasks if they fail
func (t *AuroraTask) IsService(isService bool) *AuroraTask {
t.task.IsService = isService
return t
}
//set priority for preemption or priority-queueing
func (t *AuroraTask) Priority(priority int32) *AuroraTask {
t.task.Priority = priority
return t
}
func (t *AuroraTask) Production(production bool) *AuroraTask {
t.task.Production = &production
return t
}
// Add a list of URIs with the same extract and cache configuration. Scheduler must have
// --enable_mesos_fetcher flag enabled. Currently there is no duplicate detection.
func (t *AuroraTask) AddURIs(extract bool, cache bool, values ...string) *AuroraTask {
for _, value := range values {
t.task.MesosFetcherUris = append(
t.task.MesosFetcherUris,
&aurora.MesosFetcherURI{Value: value, Extract: &extract, Cache: &cache})
}
return t
}
// Adds a Mesos label to the job. Note that Aurora will add the
// prefix "org.apache.aurora.metadata." to the beginning of each key.
func (t *AuroraTask) AddLabel(key string, value string) *AuroraTask {
t.task.Metadata = append(t.task.Metadata, &aurora.Metadata{Key: key, Value: value})
return t
}
// Add a named port to the job configuration These are random ports as it's
// not currently possible to request specific ports using Aurora.
func (t *AuroraTask) AddNamedPorts(names ...string) *AuroraTask {
t.portCount += len(names)
for _, name := range names {
t.task.Resources = append(t.task.Resources, &aurora.Resource{NamedPort: &name})
}
return t
}
// Adds a request for a number of ports to the job configuration. The names chosen for these ports
// will be org.apache.aurora.port.X, where X is the current port count for the job configuration
// starting at 0. These are random ports as it's not currently possible to request
// specific ports using Aurora.
func (t *AuroraTask) AddPorts(num int) *AuroraTask {
start := t.portCount
t.portCount += num
for i := start; i < t.portCount; i++ {
portName := portPrefix + strconv.Itoa(i)
t.task.Resources = append(t.task.Resources, &aurora.Resource{NamedPort: &portName})
}
return t
}
// From Aurora Docs:
// Add a Value constraint
// name - Mesos slave attribute that the constraint is matched against.
// If negated = true , treat this as a 'not' - to avoid specific values.
// Values - list of values we look for in attribute name
func (t *AuroraTask) AddValueConstraint(name string, negated bool, values ...string) *AuroraTask {
t.task.Constraints = append(t.task.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: &aurora.ValueConstraint{
Negated: negated,
Values: values,
},
Limit: nil,
},
})
return t
}
// From Aurora Docs:
// A constraint that specifies the maximum number of active tasks on a host with
// a matching attribute that may be scheduled simultaneously.
func (t *AuroraTask) AddLimitConstraint(name string, limit int32) *AuroraTask {
t.task.Constraints = append(t.task.Constraints,
&aurora.Constraint{
Name: name,
Constraint: &aurora.TaskConstraint{
Value: nil,
Limit: &aurora.LimitConstraint{Limit: limit},
},
})
return t
}
// From Aurora Docs:
// dedicated attribute. Aurora treats this specially, and only allows matching jobs
// to run on these machines, and will only schedule matching jobs on these machines.
// When a job is created, the scheduler requires that the $role component matches
// the role field in the job configuration, and will reject the job creation otherwise.
// A wildcard (*) may be used for the role portion of the dedicated attribute, which
// will allow any owner to elect for a job to run on the host(s)
func (t *AuroraTask) AddDedicatedConstraint(role, name string) *AuroraTask {
t.AddValueConstraint(dedicated, false, role+"/"+name)
return t
}
// Set a container to run for the job configuration to run.
func (t *AuroraTask) Container(container Container) *AuroraTask {
t.task.Container = container.Build()
return t
}
func (t *AuroraTask) TaskConfig() *aurora.TaskConfig {
return t.task
}
func (t *AuroraTask) JobKey() aurora.JobKey {
return *t.task.Job
}
func (t *AuroraTask) Clone() *AuroraTask {
newTask := TaskFromThrift(t.task)
if t.thermos != nil {
newTask.ThermosExecutor(*t.thermos.Clone())
}
return newTask
}
func (t *AuroraTask) ThermosExecutor(thermos ThermosExecutor) *AuroraTask {
t.thermos = &thermos
return t
}
func (t *AuroraTask) BuildThermosPayload() error {
if t.thermos != nil {
// Set the correct resources
if t.resources[CPU].NumCpus != nil {
t.thermos.cpu(*t.resources[CPU].NumCpus)
}
if t.resources[RAM].RamMb != nil {
t.thermos.ram(*t.resources[RAM].RamMb)
}
if t.resources[DISK].DiskMb != nil {
t.thermos.disk(*t.resources[DISK].DiskMb)
}
if t.resources[GPU] != nil && t.resources[GPU].NumGpus != nil {
t.thermos.gpu(*t.resources[GPU].NumGpus)
}
payload, err := json.Marshal(t.thermos)
if err != nil {
return err
}
t.ExecutorName(aurora.AURORA_EXECUTOR_NAME)
t.ExecutorData(string(payload))
}
return nil
}
// Set a partition policy for the job configuration to implement.
func (t *AuroraTask) PartitionPolicy(policy aurora.PartitionPolicy) *AuroraTask {
t.task.PartitionPolicy = &policy
return t
}

59
task_test.go Normal file
View file

@ -0,0 +1,59 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis_test
import (
"testing"
realis "github.com/aurora-scheduler/gorealis/v2"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/stretchr/testify/assert"
)
func TestAuroraTask_Clone(t *testing.T) {
task0 := realis.NewTask().
Environment("development").
Role("ubuntu").
Name("this_is_a_test").
ExecutorName(aurora.AURORA_EXECUTOR_NAME).
ExecutorData("{fake:payload}").
CPU(10).
RAM(643).
Disk(1000).
IsService(true).
Priority(1).
Production(false).
AddPorts(10).
Tier("preferred").
MaxFailure(23).
AddURIs(true, true, "testURI").
AddLabel("Test", "Value").
AddNamedPorts("test").
AddValueConstraint("test", false, "testing").
AddLimitConstraint("test_limit", 1).
AddDedicatedConstraint("ubuntu", "name").
Container(realis.NewDockerContainer().AddParameter("hello", "world").Image("testImg"))
task1 := task0.Clone()
assert.EqualValues(t, task0, task1, "Clone does not return the correct deep copy of AuroraTask")
task0.Container(realis.NewMesosContainer().
AppcImage("test", "testing").
AddVolume("test", "test", aurora.Mode_RW))
task2 := task0.Clone()
assert.EqualValues(t, task0, task2, "Clone does not return the correct deep copy of AuroraTask")
}

195
thermos.go Normal file
View file

@ -0,0 +1,195 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import "encoding/json"
type ThermosExecutor struct {
Task ThermosTask `json:"task""`
order *ThermosConstraint `json:"-"`
}
type ThermosTask struct {
Processes map[string]*ThermosProcess `json:"processes"`
Constraints []*ThermosConstraint `json:"constraints"`
Resources thermosResources `json:"resources"`
}
type ThermosConstraint struct {
Order []string `json:"order,omitempty"`
}
// This struct should always be controlled by the Aurora job struct.
// Therefore it is private.
type thermosResources struct {
CPU *float64 `json:"cpu,omitempty"`
Disk *int64 `json:"disk,omitempty"`
RAM *int64 `json:"ram,omitempty"`
GPU *int64 `json:"gpu,omitempty"`
}
type ThermosProcess struct {
Name string `json:"name"`
Cmdline string `json:"cmdline"`
Daemon bool `json:"daemon"`
Ephemeral bool `json:"ephemeral"`
MaxFailures int `json:"max_failures"`
MinDuration int `json:"min_duration"`
Final bool `json:"final"`
}
func NewThermosProcess(name, command string) ThermosProcess {
return ThermosProcess{
Name: name,
Cmdline: command,
MaxFailures: 1,
Daemon: false,
Ephemeral: false,
MinDuration: 5,
Final: false}
}
// Processes must have unique names. Adding a process whose name already exists will
// result in overwriting the previous version of the process.
func (t *ThermosExecutor) AddProcess(process ThermosProcess) *ThermosExecutor {
if len(t.Task.Processes) == 0 {
t.Task.Processes = make(map[string]*ThermosProcess, 0)
}
t.Task.Processes[process.Name] = &process
// Add Process to order
t.addToOrder(process.Name)
return t
}
// Only constraint that should be added for now is the order of execution, therefore this
// receiver is private.
func (t *ThermosExecutor) addConstraint(constraint *ThermosConstraint) *ThermosExecutor {
if len(t.Task.Constraints) == 0 {
t.Task.Constraints = make([]*ThermosConstraint, 0)
}
t.Task.Constraints = append(t.Task.Constraints, constraint)
return t
}
// Order in which the Processes should be executed. Index 0 will be executed first, index N will be executed last.
func (t *ThermosExecutor) ProcessOrder(order ...string) *ThermosExecutor {
if t.order == nil {
t.order = &ThermosConstraint{}
t.addConstraint(t.order)
}
t.order.Order = order
return t
}
// Add Process to execution order. By default this is a FIFO setup. Custom order can be given by overriding
// with ProcessOrder
func (t *ThermosExecutor) addToOrder(name string) {
if t.order == nil {
t.order = &ThermosConstraint{Order: make([]string, 0)}
t.addConstraint(t.order)
}
t.order.Order = append(t.order.Order, name)
}
// Ram is determined by the job object.
func (t *ThermosExecutor) ram(ram int64) {
// Convert from bytes to MiB
ram *= 1024 ^ 2
t.Task.Resources.RAM = &ram
}
// Disk is determined by the job object.
func (t *ThermosExecutor) disk(disk int64) {
// Convert from bytes to MiB
disk *= 1024 ^ 2
t.Task.Resources.Disk = &disk
}
// CPU is determined by the job object.
func (t *ThermosExecutor) cpu(cpu float64) {
t.Task.Resources.CPU = &cpu
}
// GPU is determined by the job object.
func (t *ThermosExecutor) gpu(gpu int64) {
t.Task.Resources.GPU = &gpu
}
// Deep copy of Thermos executor
func (t *ThermosExecutor) Clone() *ThermosExecutor {
tNew := ThermosExecutor{}
if t.order != nil {
tNew.order = &ThermosConstraint{Order: t.order.Order}
tNew.addConstraint(tNew.order)
}
tNew.Task.Processes = make(map[string]*ThermosProcess)
for name, process := range t.Task.Processes {
newProcess := *process
tNew.Task.Processes[name] = &newProcess
}
tNew.Task.Resources = t.Task.Resources
return &tNew
}
type thermosTaskJSON struct {
Processes []*ThermosProcess `json:"processes"`
Constraints []*ThermosConstraint `json:"constraints"`
Resources thermosResources `json:"resources"`
}
// Custom Marshaling for Thermos Task to match what Thermos expects
func (t *ThermosTask) MarshalJSON() ([]byte, error) {
// Convert map to array to match what Thermos expects
processes := make([]*ThermosProcess, 0)
for _, process := range t.Processes {
processes = append(processes, process)
}
return json.Marshal(&thermosTaskJSON{
Processes: processes,
Constraints: t.Constraints,
Resources: t.Resources,
})
}
// Custom Unmarshaling to match what Thermos would contain
func (t *ThermosTask) UnmarshalJSON(data []byte) error {
// Thermos format
aux := &thermosTaskJSON{}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
processes := make(map[string]*ThermosProcess)
for _, process := range aux.Processes {
processes[process.Name] = process
}
return nil
}

71
thermos_test.go Normal file
View file

@ -0,0 +1,71 @@
package realis
import (
"encoding/json"
"testing"
"github.com/apache/thrift/lib/go/thrift"
"github.com/stretchr/testify/assert"
)
func TestThermosTask(t *testing.T) {
// Test that we can successfully deserialize a minimum subset of an Aurora generated thermos payload
thermosJSON := []byte(
`{
"task": {
"processes": [
{
"daemon": false,
"name": "hello",
"ephemeral": false,
"max_failures": 1,
"min_duration": 5,
"cmdline": "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
"final": false
}
],
"resources": {
"gpu": 0,
"disk": 134217728,
"ram": 134217728,
"cpu": 1.1
},
"constraints": [
{
"order": [
"hello"
]
}
]
}
}`)
thermos := ThermosExecutor{}
err := json.Unmarshal(thermosJSON, &thermos)
assert.NoError(t, err)
process := &ThermosProcess{
Daemon: false,
Name: "hello",
Ephemeral: false,
MaxFailures: 1,
MinDuration: 5,
Cmdline: "\n while true; do\n echo hello world from gorealis\n sleep 10\n done\n ",
Final: false,
}
constraint := &ThermosConstraint{Order: []string{process.Name}}
thermosExpected := ThermosExecutor{
Task: ThermosTask{
Processes: map[string]*ThermosProcess{process.Name: process},
Constraints: []*ThermosConstraint{constraint},
Resources: thermosResources{CPU: thrift.Float64Ptr(1.1),
Disk: thrift.Int64Ptr(134217728),
RAM: thrift.Int64Ptr(134217728),
GPU: thrift.Int64Ptr(0)}}}
assert.ObjectsAreEqualValues(thermosExpected, thermos)
}

View file

@ -1,188 +0,0 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package realis
import (
"github.com/paypal/gorealis/gen-go/apache/aurora"
)
// UpdateJob is a structure to collect all information required to create job update.
type UpdateJob struct {
Job // SetInstanceCount for job is hidden, access via full qualifier
req *aurora.JobUpdateRequest
}
// NewDefaultUpdateJob creates an UpdateJob object with opinionated default settings.
func NewDefaultUpdateJob(config *aurora.TaskConfig) *UpdateJob {
req := aurora.NewJobUpdateRequest()
req.TaskConfig = config
req.Settings = NewUpdateSettings()
job, ok := NewJob().(*AuroraJob)
if !ok {
// This should never happen but it is here as a safeguard
return nil
}
job.jobConfig.TaskConfig = config
// Rebuild resource map from TaskConfig
for _, ptr := range config.Resources {
if ptr.NumCpus != nil {
job.resources[CPU].NumCpus = ptr.NumCpus
continue // Guard against Union violations that Go won't enforce
}
if ptr.RamMb != nil {
job.resources[RAM].RamMb = ptr.RamMb
continue
}
if ptr.DiskMb != nil {
job.resources[DISK].DiskMb = ptr.DiskMb
continue
}
if ptr.NumGpus != nil {
job.resources[GPU] = &aurora.Resource{NumGpus: ptr.NumGpus}
continue
}
}
// Mirrors defaults set by Pystachio
req.Settings.UpdateGroupSize = 1
req.Settings.WaitForBatchCompletion = false
req.Settings.MinWaitInInstanceRunningMs = 45000
req.Settings.MaxPerInstanceFailures = 0
req.Settings.MaxFailedInstances = 0
req.Settings.RollbackOnFailure = true
//TODO(rdelvalle): Deep copy job struct to avoid unexpected behavior
return &UpdateJob{Job: job, req: req}
}
// NewUpdateJob creates an UpdateJob object wihtout default settings.
func NewUpdateJob(config *aurora.TaskConfig, settings *aurora.JobUpdateSettings) *UpdateJob {
req := aurora.NewJobUpdateRequest()
req.TaskConfig = config
req.Settings = settings
job, ok := NewJob().(*AuroraJob)
if !ok {
// This should never happen but it is here as a safeguard
return nil
}
job.jobConfig.TaskConfig = config
// Rebuild resource map from TaskConfig
for _, ptr := range config.Resources {
if ptr.NumCpus != nil {
job.resources[CPU].NumCpus = ptr.NumCpus
continue // Guard against Union violations that Go won't enforce
}
if ptr.RamMb != nil {
job.resources[RAM].RamMb = ptr.RamMb
continue
}
if ptr.DiskMb != nil {
job.resources[DISK].DiskMb = ptr.DiskMb
continue
}
if ptr.NumGpus != nil {
job.resources[GPU] = &aurora.Resource{}
job.resources[GPU].NumGpus = ptr.NumGpus
continue // Guard against Union violations that Go won't enforce
}
}
//TODO(rdelvalle): Deep copy job struct to avoid unexpected behavior
return &UpdateJob{Job: job, req: req}
}
// InstanceCount sets instance count the job will have after the update.
func (u *UpdateJob) InstanceCount(inst int32) *UpdateJob {
u.req.InstanceCount = inst
return u
}
// BatchSize sets the max number of instances being updated at any given moment.
func (u *UpdateJob) BatchSize(size int32) *UpdateJob {
u.req.Settings.UpdateGroupSize = size
return u
}
// WatchTime sets the minimum number of seconds a shard must remain in RUNNING state before considered a success.
func (u *UpdateJob) WatchTime(ms int32) *UpdateJob {
u.req.Settings.MinWaitInInstanceRunningMs = ms
return u
}
// WaitForBatchCompletion configures the job update to wait for all instances in a group to be done before moving on.
func (u *UpdateJob) WaitForBatchCompletion(batchWait bool) *UpdateJob {
u.req.Settings.WaitForBatchCompletion = batchWait
return u
}
// MaxPerInstanceFailures sets the max number of instance failures to tolerate before marking instance as FAILED.
func (u *UpdateJob) MaxPerInstanceFailures(inst int32) *UpdateJob {
u.req.Settings.MaxPerInstanceFailures = inst
return u
}
// MaxFailedInstances sets the max number of FAILED instances to tolerate before terminating the update.
func (u *UpdateJob) MaxFailedInstances(inst int32) *UpdateJob {
u.req.Settings.MaxFailedInstances = inst
return u
}
// RollbackOnFail configure the job to rollback automatically after a job update fails.
func (u *UpdateJob) RollbackOnFail(rollback bool) *UpdateJob {
u.req.Settings.RollbackOnFailure = rollback
return u
}
// NewUpdateSettings return an opinionated set of job update settings.
func (u *UpdateJob) BatchUpdateStrategy(strategy aurora.BatchJobUpdateStrategy) *UpdateJob {
u.req.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{BatchStrategy: &strategy}
return u
}
func (u *UpdateJob) QueueUpdateStrategy(strategy aurora.QueueJobUpdateStrategy) *UpdateJob {
u.req.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{QueueStrategy: &strategy}
return u
}
func (u *UpdateJob) VariableBatchStrategy(strategy aurora.VariableBatchJobUpdateStrategy) *UpdateJob {
u.req.Settings.UpdateStrategy = &aurora.JobUpdateStrategy{VarBatchStrategy: &strategy}
return u
}
func NewUpdateSettings() *aurora.JobUpdateSettings {
us := new(aurora.JobUpdateSettings)
// Mirrors defaults set by Pystachio
us.UpdateGroupSize = 1
us.WaitForBatchCompletion = false
us.MinWaitInInstanceRunningMs = 45000
us.MaxPerInstanceFailures = 0
us.MaxFailedInstances = 0
us.RollbackOnFailure = true
return us
}

70
util.go
View file

@ -4,40 +4,15 @@ import (
"net/url"
"strings"
"github.com/paypal/gorealis/gen-go/apache/aurora"
"github.com/aurora-scheduler/gorealis/v2/gen-go/apache/aurora"
"github.com/pkg/errors"
)
const apiPath = "/api"
// ActiveStates - States a task may be in when active.
var ActiveStates = make(map[aurora.ScheduleStatus]bool)
// SlaveAssignedStates - States a task may be in when it has already been assigned to a Mesos agent.
var SlaveAssignedStates = make(map[aurora.ScheduleStatus]bool)
// LiveStates - States a task may be in when it is live (e.g. able to take traffic)
var LiveStates = make(map[aurora.ScheduleStatus]bool)
// TerminalStates - Set of states a task may not transition away from.
var TerminalStates = make(map[aurora.ScheduleStatus]bool)
// ActiveJobUpdateStates - States a Job Update may be in where it is considered active.
var ActiveJobUpdateStates = make(map[aurora.JobUpdateStatus]bool)
// TerminalJobUpdateStates returns a slice containing all the terminal states an update may end up in.
// This is a function in order to avoid having a slice that can be accidentally mutated.
func TerminalUpdateStates() []aurora.JobUpdateStatus {
return []aurora.JobUpdateStatus{
aurora.JobUpdateStatus_ROLLED_FORWARD,
aurora.JobUpdateStatus_ROLLED_BACK,
aurora.JobUpdateStatus_ABORTED,
aurora.JobUpdateStatus_ERROR,
aurora.JobUpdateStatus_FAILED,
}
}
// AwaitingPulseJobUpdateStates - States a job update may be in where it is waiting for a pulse.
var AwaitingPulseJobUpdateStates = make(map[aurora.JobUpdateStatus]bool)
func init() {
@ -65,14 +40,26 @@ func init() {
}
}
func validateAuroraURL(location string) (string, error) {
// TerminalUpdateStates returns a slice containing all the terminal states an update may be in.
// This is a function in order to avoid having a slice that can be accidentally mutated.
func TerminalUpdateStates() []aurora.JobUpdateStatus {
return []aurora.JobUpdateStatus{
aurora.JobUpdateStatus_ROLLED_FORWARD,
aurora.JobUpdateStatus_ROLLED_BACK,
aurora.JobUpdateStatus_ABORTED,
aurora.JobUpdateStatus_ERROR,
aurora.JobUpdateStatus_FAILED,
}
}
func validateAuroraAddress(address string) (string, error) {
// If no protocol defined, assume http
if !strings.Contains(location, "://") {
location = "http://" + location
if !strings.Contains(address, "://") {
address = "http://" + address
}
u, err := url.Parse(location)
u, err := url.Parse(address)
if err != nil {
return "", errors.Wrap(err, "error parsing url")
@ -92,8 +79,7 @@ func validateAuroraURL(location string) (string, error) {
return "", errors.Errorf("only protocols http and https are supported %v\n", u.Scheme)
}
// This could theoretically be elsewhwere but we'll be strict for the sake of simplicty
if u.Path != apiPath {
if u.Path != "/api" {
return "", errors.Errorf("expected /api path %v\n", u.Path)
}
@ -118,3 +104,23 @@ func calculateCurrentBatch(updatingInstances int32, batchSizes []int32) int {
}
return batchCount
}
func ResourcesToMap(resources []*aurora.Resource) map[string]float64 {
result := map[string]float64{}
for _, resource := range resources {
if resource.NumCpus != nil {
result["cpus"] += *resource.NumCpus
} else if resource.RamMb != nil {
result["mem"] += float64(*resource.RamMb)
} else if resource.DiskMb != nil {
result["disk"] += float64(*resource.DiskMb)
} else if resource.NamedPort != nil {
result["ports"]++
} else if resource.NumGpus != nil {
result["gpus"] += float64(*resource.NumGpus)
}
}
return result
}

View file

@ -20,50 +20,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestAuroraURLValidator(t *testing.T) {
t.Run("badURL", func(t *testing.T) {
url, err := validateAuroraURL("http://badurl.com/badpath")
assert.Empty(t, url)
assert.Error(t, err)
})
t.Run("URLHttp", func(t *testing.T) {
url, err := validateAuroraURL("http://goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLHttps", func(t *testing.T) {
url, err := validateAuroraURL("https://goodurl.com:8081/api")
assert.Equal(t, "https://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoPath", func(t *testing.T) {
url, err := validateAuroraURL("http://goodurl.com:8081")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("ipAddrNoPath", func(t *testing.T) {
url, err := validateAuroraURL("http://192.168.1.33:8081")
assert.Equal(t, "http://192.168.1.33:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocol", func(t *testing.T) {
url, err := validateAuroraURL("goodurl.com:8081/api")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
t.Run("URLNoProtocolNoPathNoPort", func(t *testing.T) {
url, err := validateAuroraURL("goodurl.com")
assert.Equal(t, "http://goodurl.com:8081/api", url)
assert.NoError(t, err)
})
}
func TestCurrentBatchCalculator(t *testing.T) {
t.Run("singleBatchOverflow", func(t *testing.T) {
curBatch := calculateCurrentBatch(10, []int32{2})

View file

@ -1,164 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
const (
UNKNOWN_APPLICATION_EXCEPTION = 0
UNKNOWN_METHOD = 1
INVALID_MESSAGE_TYPE_EXCEPTION = 2
WRONG_METHOD_NAME = 3
BAD_SEQUENCE_ID = 4
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
)
var defaultApplicationExceptionMessage = map[int32]string{
UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception",
UNKNOWN_METHOD: "unknown method",
INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type",
WRONG_METHOD_NAME: "wrong method name",
BAD_SEQUENCE_ID: "bad sequence ID",
MISSING_RESULT: "missing result",
INTERNAL_ERROR: "unknown internal error",
PROTOCOL_ERROR: "unknown protocol error",
}
// Application level Thrift exception
type TApplicationException interface {
TException
TypeId() int32
Read(iprot TProtocol) error
Write(oprot TProtocol) error
}
type tApplicationException struct {
message string
type_ int32
}
func (e tApplicationException) Error() string {
if e.message != "" {
return e.message
}
return defaultApplicationExceptionMessage[e.type_]
}
func NewTApplicationException(type_ int32, message string) TApplicationException {
return &tApplicationException{message, type_}
}
func (p *tApplicationException) TypeId() int32 {
return p.type_
}
func (p *tApplicationException) Read(iprot TProtocol) error {
// TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
if err != nil {
return err
}
message := ""
type_ := int32(UNKNOWN_APPLICATION_EXCEPTION)
for {
_, ttype, id, err := iprot.ReadFieldBegin()
if err != nil {
return err
}
if ttype == STOP {
break
}
switch id {
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
return err
}
p.message = message
p.type_ = type_
return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {
err = oprot.WriteStructBegin("TApplicationException")
if len(p.Error()) > 0 {
err = oprot.WriteFieldBegin("message", STRING, 1)
if err != nil {
return
}
err = oprot.WriteString(p.Error())
if err != nil {
return
}
err = oprot.WriteFieldEnd()
if err != nil {
return
}
}
err = oprot.WriteFieldBegin("type", I32, 2)
if err != nil {
return
}
err = oprot.WriteI32(p.type_)
if err != nil {
return
}
err = oprot.WriteFieldEnd()
if err != nil {
return
}
err = oprot.WriteFieldStop()
if err != nil {
return
}
err = oprot.WriteStructEnd()
return
}

View file

@ -1,41 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestTApplicationException(t *testing.T) {
exc := NewTApplicationException(UNKNOWN_APPLICATION_EXCEPTION, "")
if exc.Error() != defaultApplicationExceptionMessage[UNKNOWN_APPLICATION_EXCEPTION] {
t.Fatalf("Expected empty string for exception but found '%s'", exc.Error())
}
if exc.TypeId() != UNKNOWN_APPLICATION_EXCEPTION {
t.Fatalf("Expected type UNKNOWN for exception but found '%v'", exc.TypeId())
}
exc = NewTApplicationException(WRONG_METHOD_NAME, "junk_method")
if exc.Error() != "junk_method" {
t.Fatalf("Expected 'junk_method' for exception but found '%s'", exc.Error())
}
if exc.TypeId() != WRONG_METHOD_NAME {
t.Fatalf("Expected type WRONG_METHOD_NAME for exception but found '%v'", exc.TypeId())
}
}

View file

@ -1,509 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
)
type TBinaryProtocol struct {
trans TRichTransport
origTransport TTransport
reader io.Reader
writer io.Writer
strictRead bool
strictWrite bool
buffer [64]byte
}
type TBinaryProtocolFactory struct {
strictRead bool
strictWrite bool
}
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
return NewTBinaryProtocol(t, false, true)
}
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite}
if et, ok := t.(TRichTransport); ok {
p.trans = et
} else {
p.trans = NewTRichTransport(t)
}
p.reader = p.trans
p.writer = p.trans
return p
}
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
return NewTBinaryProtocolFactory(false, true)
}
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
return &TBinaryProtocolFactory{strictRead: strictRead, strictWrite: strictWrite}
}
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
return NewTBinaryProtocol(t, p.strictRead, p.strictWrite)
}
/**
* Writing Methods
*/
func (p *TBinaryProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
if p.strictWrite {
version := uint32(VERSION_1) | uint32(typeId)
e := p.WriteI32(int32(version))
if e != nil {
return e
}
e = p.WriteString(name)
if e != nil {
return e
}
e = p.WriteI32(seqId)
return e
} else {
e := p.WriteString(name)
if e != nil {
return e
}
e = p.WriteByte(int8(typeId))
if e != nil {
return e
}
e = p.WriteI32(seqId)
return e
}
return nil
}
func (p *TBinaryProtocol) WriteMessageEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteStructBegin(name string) error {
return nil
}
func (p *TBinaryProtocol) WriteStructEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
e := p.WriteByte(int8(typeId))
if e != nil {
return e
}
e = p.WriteI16(id)
return e
}
func (p *TBinaryProtocol) WriteFieldEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteFieldStop() error {
e := p.WriteByte(STOP)
return e
}
func (p *TBinaryProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
e := p.WriteByte(int8(keyType))
if e != nil {
return e
}
e = p.WriteByte(int8(valueType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteMapEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteListBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteListEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteSetBegin(elemType TType, size int) error {
e := p.WriteByte(int8(elemType))
if e != nil {
return e
}
e = p.WriteI32(int32(size))
return e
}
func (p *TBinaryProtocol) WriteSetEnd() error {
return nil
}
func (p *TBinaryProtocol) WriteBool(value bool) error {
if value {
return p.WriteByte(1)
}
return p.WriteByte(0)
}
func (p *TBinaryProtocol) WriteByte(value int8) error {
e := p.trans.WriteByte(byte(value))
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI16(value int16) error {
v := p.buffer[0:2]
binary.BigEndian.PutUint16(v, uint16(value))
_, e := p.writer.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI32(value int32) error {
v := p.buffer[0:4]
binary.BigEndian.PutUint32(v, uint32(value))
_, e := p.writer.Write(v)
return NewTProtocolException(e)
}
func (p *TBinaryProtocol) WriteI64(value int64) error {
v := p.buffer[0:8]
binary.BigEndian.PutUint64(v, uint64(value))
_, err := p.writer.Write(v)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteDouble(value float64) error {
return p.WriteI64(int64(math.Float64bits(value)))
}
func (p *TBinaryProtocol) WriteString(value string) error {
e := p.WriteI32(int32(len(value)))
if e != nil {
return e
}
_, err := p.trans.WriteString(value)
return NewTProtocolException(err)
}
func (p *TBinaryProtocol) WriteBinary(value []byte) error {
e := p.WriteI32(int32(len(value)))
if e != nil {
return e
}
_, err := p.writer.Write(value)
return NewTProtocolException(err)
}
/**
* Reading methods
*/
func (p *TBinaryProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
size, e := p.ReadI32()
if e != nil {
return "", typeId, 0, NewTProtocolException(e)
}
if size < 0 {
typeId = TMessageType(size & 0x0ff)
version := int64(int64(size) & VERSION_MASK)
if version != VERSION_1 {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
}
name, e = p.ReadString()
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
seqId, e = p.ReadI32()
if e != nil {
return name, typeId, seqId, NewTProtocolException(e)
}
return name, typeId, seqId, nil
}
if p.strictRead {
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
}
name, e2 := p.readStringBody(size)
if e2 != nil {
return name, typeId, seqId, e2
}
b, e3 := p.ReadByte()
if e3 != nil {
return name, typeId, seqId, e3
}
typeId = TMessageType(b)
seqId, e4 := p.ReadI32()
if e4 != nil {
return name, typeId, seqId, e4
}
return name, typeId, seqId, nil
}
func (p *TBinaryProtocol) ReadMessageEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadStructBegin() (name string, err error) {
return
}
func (p *TBinaryProtocol) ReadStructEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadFieldBegin() (name string, typeId TType, seqId int16, err error) {
t, err := p.ReadByte()
typeId = TType(t)
if err != nil {
return name, typeId, seqId, err
}
if t != STOP {
seqId, err = p.ReadI16()
}
return name, typeId, seqId, err
}
func (p *TBinaryProtocol) ReadFieldEnd() error {
return nil
}
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
func (p *TBinaryProtocol) ReadMapBegin() (kType, vType TType, size int, err error) {
k, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
kType = TType(k)
v, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
vType = TType(v)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return kType, vType, size, nil
}
func (p *TBinaryProtocol) ReadMapEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadListBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return
}
func (p *TBinaryProtocol) ReadListEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadSetBegin() (elemType TType, size int, err error) {
b, e := p.ReadByte()
if e != nil {
err = NewTProtocolException(e)
return
}
elemType = TType(b)
size32, e := p.ReadI32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
return elemType, size, nil
}
func (p *TBinaryProtocol) ReadSetEnd() error {
return nil
}
func (p *TBinaryProtocol) ReadBool() (bool, error) {
b, e := p.ReadByte()
v := true
if b != 1 {
v = false
}
return v, e
}
func (p *TBinaryProtocol) ReadByte() (int8, error) {
v, err := p.trans.ReadByte()
return int8(v), err
}
func (p *TBinaryProtocol) ReadI16() (value int16, err error) {
buf := p.buffer[0:2]
err = p.readAll(buf)
value = int16(binary.BigEndian.Uint16(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI32() (value int32, err error) {
buf := p.buffer[0:4]
err = p.readAll(buf)
value = int32(binary.BigEndian.Uint32(buf))
return value, err
}
func (p *TBinaryProtocol) ReadI64() (value int64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
value = int64(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadDouble() (value float64, err error) {
buf := p.buffer[0:8]
err = p.readAll(buf)
value = math.Float64frombits(binary.BigEndian.Uint64(buf))
return value, err
}
func (p *TBinaryProtocol) ReadString() (value string, err error) {
size, e := p.ReadI32()
if e != nil {
return "", e
}
if size < 0 {
err = invalidDataLength
return
}
return p.readStringBody(size)
}
func (p *TBinaryProtocol) ReadBinary() ([]byte, error) {
size, e := p.ReadI32()
if e != nil {
return nil, e
}
if size < 0 {
return nil, invalidDataLength
}
isize := int(size)
buf := make([]byte, isize)
_, err := io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(err)
}
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TBinaryProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TBinaryProtocol) Transport() TTransport {
return p.origTransport
}
func (p *TBinaryProtocol) readAll(buf []byte) error {
_, err := io.ReadFull(p.reader, buf)
return NewTProtocolException(err)
}
const readLimit = 32768
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
if size < 0 {
return "", nil
}
var (
buf bytes.Buffer
e error
b []byte
)
switch {
case int(size) <= len(p.buffer):
b = p.buffer[:size] // avoids allocation for small reads
case int(size) < readLimit:
b = make([]byte, size)
default:
b = make([]byte, readLimit)
}
for size > 0 {
_, e = io.ReadFull(p.trans, b)
buf.Write(b)
if e != nil {
break
}
size -= readLimit
if size < readLimit && size > 0 {
b = b[:size]
}
}
return buf.String(), NewTProtocolException(e)
}

View file

@ -1,28 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestReadWriteBinaryProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault())
}

View file

@ -1,92 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"context"
)
type TBufferedTransportFactory struct {
size int
}
type TBufferedTransport struct {
bufio.ReadWriter
tp TTransport
}
func (p *TBufferedTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
return NewTBufferedTransport(trans, p.size), nil
}
func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory {
return &TBufferedTransportFactory{size: bufferSize}
}
func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport {
return &TBufferedTransport{
ReadWriter: bufio.ReadWriter{
Reader: bufio.NewReaderSize(trans, bufferSize),
Writer: bufio.NewWriterSize(trans, bufferSize),
},
tp: trans,
}
}
func (p *TBufferedTransport) IsOpen() bool {
return p.tp.IsOpen()
}
func (p *TBufferedTransport) Open() (err error) {
return p.tp.Open()
}
func (p *TBufferedTransport) Close() (err error) {
return p.tp.Close()
}
func (p *TBufferedTransport) Read(b []byte) (int, error) {
n, err := p.ReadWriter.Read(b)
if err != nil {
p.ReadWriter.Reader.Reset(p.tp)
}
return n, err
}
func (p *TBufferedTransport) Write(b []byte) (int, error) {
n, err := p.ReadWriter.Write(b)
if err != nil {
p.ReadWriter.Writer.Reset(p.tp)
}
return n, err
}
func (p *TBufferedTransport) Flush(ctx context.Context) error {
if err := p.ReadWriter.Flush(); err != nil {
p.ReadWriter.Writer.Reset(p.tp)
return err
}
return p.tp.Flush(ctx)
}
func (p *TBufferedTransport) RemainingBytes() (num_bytes uint64) {
return p.tp.RemainingBytes()
}

View file

@ -1,29 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestBufferedTransport(t *testing.T) {
trans := NewTBufferedTransport(NewTMemoryBuffer(), 10240)
TransportTest(t, trans, trans)
}

View file

@ -1,85 +0,0 @@
package thrift
import (
"context"
"fmt"
)
type TClient interface {
Call(ctx context.Context, method string, args, result TStruct) error
}
type TStandardClient struct {
seqId int32
iprot, oprot TProtocol
}
// TStandardClient implements TClient, and uses the standard message format for Thrift.
// It is not safe for concurrent use.
func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
return &TStandardClient{
iprot: inputProtocol,
oprot: outputProtocol,
}
}
func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
return err
}
if err := args.Write(oprot); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
return err
}
return oprot.Flush(ctx)
}
func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
if err != nil {
return err
}
if method != rMethod {
return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
} else if seqId != rSeqId {
return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
} else if rTypeId == EXCEPTION {
var exception tApplicationException
if err := exception.Read(iprot); err != nil {
return err
}
if err := iprot.ReadMessageEnd(); err != nil {
return err
}
return &exception
} else if rTypeId != REPLY {
return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
}
if err := result.Read(iprot); err != nil {
return err
}
return iprot.ReadMessageEnd()
}
func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
p.seqId++
seqId := p.seqId
if err := p.Send(ctx, p.oprot, seqId, method, args); err != nil {
return err
}
// method is oneway
if result == nil {
return nil
}
return p.Recv(p.iprot, seqId, method, result)
}

View file

@ -1,30 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
type mockProcessor struct {
ProcessFunc func(in, out TProtocol) (bool, TException)
}
func (m *mockProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
return m.ProcessFunc(in, out)
}

View file

@ -1,810 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/binary"
"fmt"
"io"
"math"
)
const (
COMPACT_PROTOCOL_ID = 0x082
COMPACT_VERSION = 1
COMPACT_VERSION_MASK = 0x1f
COMPACT_TYPE_MASK = 0x0E0
COMPACT_TYPE_BITS = 0x07
COMPACT_TYPE_SHIFT_AMOUNT = 5
)
type tCompactType byte
const (
COMPACT_BOOLEAN_TRUE = 0x01
COMPACT_BOOLEAN_FALSE = 0x02
COMPACT_BYTE = 0x03
COMPACT_I16 = 0x04
COMPACT_I32 = 0x05
COMPACT_I64 = 0x06
COMPACT_DOUBLE = 0x07
COMPACT_BINARY = 0x08
COMPACT_LIST = 0x09
COMPACT_SET = 0x0A
COMPACT_MAP = 0x0B
COMPACT_STRUCT = 0x0C
)
var (
ttypeToCompactType map[TType]tCompactType
)
func init() {
ttypeToCompactType = map[TType]tCompactType{
STOP: STOP,
BOOL: COMPACT_BOOLEAN_TRUE,
BYTE: COMPACT_BYTE,
I16: COMPACT_I16,
I32: COMPACT_I32,
I64: COMPACT_I64,
DOUBLE: COMPACT_DOUBLE,
STRING: COMPACT_BINARY,
LIST: COMPACT_LIST,
SET: COMPACT_SET,
MAP: COMPACT_MAP,
STRUCT: COMPACT_STRUCT,
}
}
type TCompactProtocolFactory struct{}
func NewTCompactProtocolFactory() *TCompactProtocolFactory {
return &TCompactProtocolFactory{}
}
func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTCompactProtocol(trans)
}
type TCompactProtocol struct {
trans TRichTransport
origTransport TTransport
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField []int
lastFieldId int
// If we encounter a boolean field begin, save the TField here so it can
// have the value incorporated.
booleanFieldName string
booleanFieldId int16
booleanFieldPending bool
// If we read a field header, and it's a boolean field, save the boolean
// value here so that readBool can use it.
boolValue bool
boolValueIsNotNull bool
buffer [64]byte
}
// Create a TCompactProtocol given a TTransport
func NewTCompactProtocol(trans TTransport) *TCompactProtocol {
p := &TCompactProtocol{origTransport: trans, lastField: []int{}}
if et, ok := trans.(TRichTransport); ok {
p.trans = et
} else {
p.trans = NewTRichTransport(trans)
}
return p
}
//
// Public Writing methods.
//
// Write a message header to the wire. Compact Protocol messages contain the
// protocol version so we can migrate forwards in the future if need be.
func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := p.writeByteDirect(COMPACT_PROTOCOL_ID)
if err != nil {
return NewTProtocolException(err)
}
err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK))
if err != nil {
return NewTProtocolException(err)
}
_, err = p.writeVarint32(seqid)
if err != nil {
return NewTProtocolException(err)
}
e := p.WriteString(name)
return e
}
func (p *TCompactProtocol) WriteMessageEnd() error { return nil }
// Write a struct begin. This doesn't actually put anything on the wire. We
// use it as an opportunity to put special placeholder markers on the field
// stack so we can get the field id deltas correct.
func (p *TCompactProtocol) WriteStructBegin(name string) error {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return nil
}
// Write a struct end. This doesn't actually put anything on the wire. We use
// this as an opportunity to pop the last field from the current struct off
// of the field stack.
func (p *TCompactProtocol) WriteStructEnd() error {
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if typeId == BOOL {
// we want to possibly include the value, so we'll wait.
p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true
return nil
}
_, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF)
return NewTProtocolException(err)
}
// The workhorse of writeFieldBegin. It has the option of doing a
// 'type override' of the type header. This is used specifically in the
// boolean field case.
func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id int16, typeOverride byte) (int, error) {
// short lastField = lastField_.pop();
// if there's a type override, use that.
var typeToWrite byte
if typeOverride == 0xFF {
typeToWrite = byte(p.getCompactType(typeId))
} else {
typeToWrite = typeOverride
}
// check if we can use delta encoding for the field id
fieldId := int(id)
written := 0
if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 {
// write them together
err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite)
if err != nil {
return 0, err
}
} else {
// write them separate
err := p.writeByteDirect(typeToWrite)
if err != nil {
return 0, err
}
err = p.WriteI16(id)
written = 1 + 2
if err != nil {
return 0, err
}
}
p.lastFieldId = fieldId
// p.lastField.Push(field.id);
return written, nil
}
func (p *TCompactProtocol) WriteFieldEnd() error { return nil }
func (p *TCompactProtocol) WriteFieldStop() error {
err := p.writeByteDirect(STOP)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
if size == 0 {
err := p.writeByteDirect(0)
return NewTProtocolException(err)
}
_, err := p.writeVarint32(int32(size))
if err != nil {
return NewTProtocolException(err)
}
err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType)))
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteMapEnd() error { return nil }
// Write a list header.
func (p *TCompactProtocol) WriteListBegin(elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteListEnd() error { return nil }
// Write a set header.
func (p *TCompactProtocol) WriteSetBegin(elemType TType, size int) error {
_, err := p.writeCollectionBegin(elemType, size)
return NewTProtocolException(err)
}
func (p *TCompactProtocol) WriteSetEnd() error { return nil }
func (p *TCompactProtocol) WriteBool(value bool) error {
v := byte(COMPACT_BOOLEAN_FALSE)
if value {
v = byte(COMPACT_BOOLEAN_TRUE)
}
if p.booleanFieldPending {
// we haven't written the field header yet
_, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v)
p.booleanFieldPending = false
return NewTProtocolException(err)
}
// we're not part of a field, so just write the value.
err := p.writeByteDirect(v)
return NewTProtocolException(err)
}
// Write a byte. Nothing to see here!
func (p *TCompactProtocol) WriteByte(value int8) error {
err := p.writeByteDirect(byte(value))
return NewTProtocolException(err)
}
// Write an I16 as a zigzag varint.
func (p *TCompactProtocol) WriteI16(value int16) error {
_, err := p.writeVarint32(p.int32ToZigzag(int32(value)))
return NewTProtocolException(err)
}
// Write an i32 as a zigzag varint.
func (p *TCompactProtocol) WriteI32(value int32) error {
_, err := p.writeVarint32(p.int32ToZigzag(value))
return NewTProtocolException(err)
}
// Write an i64 as a zigzag varint.
func (p *TCompactProtocol) WriteI64(value int64) error {
_, err := p.writeVarint64(p.int64ToZigzag(value))
return NewTProtocolException(err)
}
// Write a double to the wire as 8 bytes.
func (p *TCompactProtocol) WriteDouble(value float64) error {
buf := p.buffer[0:8]
binary.LittleEndian.PutUint64(buf, math.Float64bits(value))
_, err := p.trans.Write(buf)
return NewTProtocolException(err)
}
// Write a string to the wire with a varint size preceding.
func (p *TCompactProtocol) WriteString(value string) error {
_, e := p.writeVarint32(int32(len(value)))
if e != nil {
return NewTProtocolException(e)
}
if len(value) > 0 {
}
_, e = p.trans.WriteString(value)
return e
}
// Write a byte array, using a varint for the size.
func (p *TCompactProtocol) WriteBinary(bin []byte) error {
_, e := p.writeVarint32(int32(len(bin)))
if e != nil {
return NewTProtocolException(e)
}
if len(bin) > 0 {
_, e = p.trans.Write(bin)
return NewTProtocolException(e)
}
return nil
}
//
// Reading methods.
//
// Read a message header.
func (p *TCompactProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
protocolId, err := p.readByteDirect()
if err != nil {
return
}
if protocolId != COMPACT_PROTOCOL_ID {
e := fmt.Errorf("Expected protocol id %02x but got %02x", COMPACT_PROTOCOL_ID, protocolId)
return "", typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, e)
}
versionAndType, err := p.readByteDirect()
if err != nil {
return
}
version := versionAndType & COMPACT_VERSION_MASK
typeId = TMessageType((versionAndType >> COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_BITS)
if version != COMPACT_VERSION {
e := fmt.Errorf("Expected version %02x but got %02x", COMPACT_VERSION, version)
err = NewTProtocolExceptionWithType(BAD_VERSION, e)
return
}
seqId, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
name, err = p.ReadString()
return
}
func (p *TCompactProtocol) ReadMessageEnd() error { return nil }
// Read a struct begin. There's nothing on the wire for this, but it is our
// opportunity to push a new struct begin marker onto the field stack.
func (p *TCompactProtocol) ReadStructBegin() (name string, err error) {
p.lastField = append(p.lastField, p.lastFieldId)
p.lastFieldId = 0
return
}
// Doesn't actually consume any wire data, just removes the last field for
// this struct from the field stack.
func (p *TCompactProtocol) ReadStructEnd() error {
// consume the last field we read off the wire.
p.lastFieldId = p.lastField[len(p.lastField)-1]
p.lastField = p.lastField[:len(p.lastField)-1]
return nil
}
// Read a field header off the wire.
func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
t, err := p.readByteDirect()
if err != nil {
return
}
// if it's a stop, then we can return immediately, as the struct is over.
if (t & 0x0f) == STOP {
return "", STOP, 0, nil
}
// mask off the 4 MSB of the type header. it could contain a field id delta.
modifier := int16((t & 0xf0) >> 4)
if modifier == 0 {
// not a delta. look ahead for the zigzag varint field id.
id, err = p.ReadI16()
if err != nil {
return
}
} else {
// has a delta. add the delta to the last read field id.
id = int16(p.lastFieldId) + modifier
}
typeId, e := p.getTType(tCompactType(t & 0x0f))
if e != nil {
err = NewTProtocolException(e)
return
}
// if this happens to be a boolean field, the value is encoded in the type
if p.isBoolType(t) {
// save the boolean value in a special instance variable.
p.boolValue = (byte(t)&0x0f == COMPACT_BOOLEAN_TRUE)
p.boolValueIsNotNull = true
}
// push the new field onto the field stack so we can keep the deltas going.
p.lastFieldId = int(id)
return
}
func (p *TCompactProtocol) ReadFieldEnd() error { return nil }
// Read a map header off the wire. If the size is zero, skip reading the key
// and value type. This means that 0-length maps will yield TMaps without the
// "correct" types.
func (p *TCompactProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
size32, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size32 < 0 {
err = invalidDataLength
return
}
size = int(size32)
keyAndValueType := byte(STOP)
if size != 0 {
keyAndValueType, err = p.readByteDirect()
if err != nil {
return
}
}
keyType, _ = p.getTType(tCompactType(keyAndValueType >> 4))
valueType, _ = p.getTType(tCompactType(keyAndValueType & 0xf))
return
}
func (p *TCompactProtocol) ReadMapEnd() error { return nil }
// Read a list header off the wire. If the list size is 0-14, the size will
// be packed into the element type header. If it's a longer list, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadListBegin() (elemType TType, size int, err error) {
size_and_type, err := p.readByteDirect()
if err != nil {
return
}
size = int((size_and_type >> 4) & 0x0f)
if size == 15 {
size2, e := p.readVarint32()
if e != nil {
err = NewTProtocolException(e)
return
}
if size2 < 0 {
err = invalidDataLength
return
}
size = int(size2)
}
elemType, e := p.getTType(tCompactType(size_and_type))
if e != nil {
err = NewTProtocolException(e)
return
}
return
}
func (p *TCompactProtocol) ReadListEnd() error { return nil }
// Read a set header off the wire. If the set size is 0-14, the size will
// be packed into the element type header. If it's a longer set, the 4 MSB
// of the element type header will be 0xF, and a varint will follow with the
// true size.
func (p *TCompactProtocol) ReadSetBegin() (elemType TType, size int, err error) {
return p.ReadListBegin()
}
func (p *TCompactProtocol) ReadSetEnd() error { return nil }
// Read a boolean off the wire. If this is a boolean field, the value should
// already have been read during readFieldBegin, so we'll just consume the
// pre-stored value. Otherwise, read a byte.
func (p *TCompactProtocol) ReadBool() (value bool, err error) {
if p.boolValueIsNotNull {
p.boolValueIsNotNull = false
return p.boolValue, nil
}
v, err := p.readByteDirect()
return v == COMPACT_BOOLEAN_TRUE, err
}
// Read a single byte off the wire. Nothing interesting here.
func (p *TCompactProtocol) ReadByte() (int8, error) {
v, err := p.readByteDirect()
if err != nil {
return 0, NewTProtocolException(err)
}
return int8(v), err
}
// Read an i16 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI16() (value int16, err error) {
v, err := p.ReadI32()
return int16(v), err
}
// Read an i32 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI32() (value int32, err error) {
v, e := p.readVarint32()
if e != nil {
return 0, NewTProtocolException(e)
}
value = p.zigzagToInt32(v)
return value, nil
}
// Read an i64 from the wire as a zigzag varint.
func (p *TCompactProtocol) ReadI64() (value int64, err error) {
v, e := p.readVarint64()
if e != nil {
return 0, NewTProtocolException(e)
}
value = p.zigzagToInt64(v)
return value, nil
}
// No magic here - just read a double off the wire.
func (p *TCompactProtocol) ReadDouble() (value float64, err error) {
longBits := p.buffer[0:8]
_, e := io.ReadFull(p.trans, longBits)
if e != nil {
return 0.0, NewTProtocolException(e)
}
return math.Float64frombits(p.bytesToUint64(longBits)), nil
}
// Reads a []byte (via readBinary), and then UTF-8 decodes it.
func (p *TCompactProtocol) ReadString() (value string, err error) {
length, e := p.readVarint32()
if e != nil {
return "", NewTProtocolException(e)
}
if length < 0 {
return "", invalidDataLength
}
if length == 0 {
return "", nil
}
var buf []byte
if length <= int32(len(p.buffer)) {
buf = p.buffer[0:length]
} else {
buf = make([]byte, length)
}
_, e = io.ReadFull(p.trans, buf)
return string(buf), NewTProtocolException(e)
}
// Read a []byte from the wire.
func (p *TCompactProtocol) ReadBinary() (value []byte, err error) {
length, e := p.readVarint32()
if e != nil {
return nil, NewTProtocolException(e)
}
if length == 0 {
return []byte{}, nil
}
if length < 0 {
return nil, invalidDataLength
}
buf := make([]byte, length)
_, e = io.ReadFull(p.trans, buf)
return buf, NewTProtocolException(e)
}
func (p *TCompactProtocol) Flush(ctx context.Context) (err error) {
return NewTProtocolException(p.trans.Flush(ctx))
}
func (p *TCompactProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TCompactProtocol) Transport() TTransport {
return p.origTransport
}
//
// Internal writing methods
//
// Abstract method for writing the start of lists and sets. List and sets on
// the wire differ only by the type indicator.
func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) {
if size <= 14 {
return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType))))
}
err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType)))
if err != nil {
return 0, err
}
m, err := p.writeVarint32(int32(size))
return 1 + m, err
}
// Write an i32 as a varint. Results in 1-5 bytes on the wire.
// TODO(pomack): make a permanent buffer like writeVarint64?
func (p *TCompactProtocol) writeVarint32(n int32) (int, error) {
i32buf := p.buffer[0:5]
idx := 0
for {
if (n & ^0x7F) == 0 {
i32buf[idx] = byte(n)
idx++
// p.writeByteDirect(byte(n));
break
// return;
} else {
i32buf[idx] = byte((n & 0x7F) | 0x80)
idx++
// p.writeByteDirect(byte(((n & 0x7F) | 0x80)));
u := uint32(n)
n = int32(u >> 7)
}
}
return p.trans.Write(i32buf[0:idx])
}
// Write an i64 as a varint. Results in 1-10 bytes on the wire.
func (p *TCompactProtocol) writeVarint64(n int64) (int, error) {
varint64out := p.buffer[0:10]
idx := 0
for {
if (n & ^0x7F) == 0 {
varint64out[idx] = byte(n)
idx++
break
} else {
varint64out[idx] = byte((n & 0x7F) | 0x80)
idx++
u := uint64(n)
n = int64(u >> 7)
}
}
return p.trans.Write(varint64out[0:idx])
}
// Convert l into a zigzag long. This allows negative numbers to be
// represented compactly as a varint.
func (p *TCompactProtocol) int64ToZigzag(l int64) int64 {
return (l << 1) ^ (l >> 63)
}
// Convert l into a zigzag long. This allows negative numbers to be
// represented compactly as a varint.
func (p *TCompactProtocol) int32ToZigzag(n int32) int32 {
return (n << 1) ^ (n >> 31)
}
func (p *TCompactProtocol) fixedUint64ToBytes(n uint64, buf []byte) {
binary.LittleEndian.PutUint64(buf, n)
}
func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) {
binary.LittleEndian.PutUint64(buf, uint64(n))
}
// Writes a byte without any possibility of all that field header nonsense.
// Used internally by other writing methods that know they need to write a byte.
func (p *TCompactProtocol) writeByteDirect(b byte) error {
return p.trans.WriteByte(b)
}
// Writes a byte without any possibility of all that field header nonsense.
func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) {
return 1, p.writeByteDirect(byte(n))
}
//
// Internal reading methods
//
// Read an i32 from the wire as a varint. The MSB of each byte is set
// if there is another byte to follow. This can read up to 5 bytes.
func (p *TCompactProtocol) readVarint32() (int32, error) {
// if the wire contains the right stuff, this will just truncate the i64 we
// read and get us the right sign.
v, err := p.readVarint64()
return int32(v), err
}
// Read an i64 from the wire as a proper varint. The MSB of each byte is set
// if there is another byte to follow. This can read up to 10 bytes.
func (p *TCompactProtocol) readVarint64() (int64, error) {
shift := uint(0)
result := int64(0)
for {
b, err := p.readByteDirect()
if err != nil {
return 0, err
}
result |= int64(b&0x7f) << shift
if (b & 0x80) != 0x80 {
break
}
shift += 7
}
return result, nil
}
// Read a byte, unlike ReadByte that reads Thrift-byte that is i8.
func (p *TCompactProtocol) readByteDirect() (byte, error) {
return p.trans.ReadByte()
}
//
// encoding helpers
//
// Convert from zigzag int to int.
func (p *TCompactProtocol) zigzagToInt32(n int32) int32 {
u := uint32(n)
return int32(u>>1) ^ -(n & 1)
}
// Convert from zigzag long to long.
func (p *TCompactProtocol) zigzagToInt64(n int64) int64 {
u := uint64(n)
return int64(u>>1) ^ -(n & 1)
}
// Note that it's important that the mask bytes are long literals,
// otherwise they'll default to ints, and when you shift an int left 56 bits,
// you just get a messed up int.
func (p *TCompactProtocol) bytesToInt64(b []byte) int64 {
return int64(binary.LittleEndian.Uint64(b))
}
// Note that it's important that the mask bytes are long literals,
// otherwise they'll default to ints, and when you shift an int left 56 bits,
// you just get a messed up int.
func (p *TCompactProtocol) bytesToUint64(b []byte) uint64 {
return binary.LittleEndian.Uint64(b)
}
//
// type testing and converting
//
func (p *TCompactProtocol) isBoolType(b byte) bool {
return (b&0x0f) == COMPACT_BOOLEAN_TRUE || (b&0x0f) == COMPACT_BOOLEAN_FALSE
}
// Given a tCompactType constant, convert it to its corresponding
// TType value.
func (p *TCompactProtocol) getTType(t tCompactType) (TType, error) {
switch byte(t) & 0x0f {
case STOP:
return STOP, nil
case COMPACT_BOOLEAN_FALSE, COMPACT_BOOLEAN_TRUE:
return BOOL, nil
case COMPACT_BYTE:
return BYTE, nil
case COMPACT_I16:
return I16, nil
case COMPACT_I32:
return I32, nil
case COMPACT_I64:
return I64, nil
case COMPACT_DOUBLE:
return DOUBLE, nil
case COMPACT_BINARY:
return STRING, nil
case COMPACT_LIST:
return LIST, nil
case COMPACT_SET:
return SET, nil
case COMPACT_MAP:
return MAP, nil
case COMPACT_STRUCT:
return STRUCT, nil
}
return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
}
// Given a TType value, find the appropriate TCompactProtocol.Types constant.
func (p *TCompactProtocol) getCompactType(t TType) tCompactType {
return ttypeToCompactType[t]
}

View file

@ -1,60 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"testing"
)
func TestReadWriteCompactProtocol(t *testing.T) {
ReadWriteProtocolTest(t, NewTCompactProtocolFactory())
transports := []TTransport{
NewTMemoryBuffer(),
NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))),
NewTFramedTransport(NewTMemoryBuffer()),
}
zlib0, _ := NewTZlibTransport(NewTMemoryBuffer(), 0)
zlib6, _ := NewTZlibTransport(NewTMemoryBuffer(), 6)
zlib9, _ := NewTZlibTransport(NewTFramedTransport(NewTMemoryBuffer()), 9)
transports = append(transports, zlib0, zlib6, zlib9)
for _, trans := range transports {
p := NewTCompactProtocol(trans)
ReadWriteBool(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteByte(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteI16(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteI32(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteI64(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteDouble(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteString(t, p, trans)
p = NewTCompactProtocol(trans)
ReadWriteBinary(t, p, trans)
trans.Close()
}
}

View file

@ -1,24 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
var defaultCtx = context.Background()

View file

@ -1,270 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"log"
)
type TDebugProtocol struct {
Delegate TProtocol
LogPrefix string
}
type TDebugProtocolFactory struct {
Underlying TProtocolFactory
LogPrefix string
}
func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory {
return &TDebugProtocolFactory{
Underlying: underlying,
LogPrefix: logPrefix,
}
}
func (t *TDebugProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return &TDebugProtocol{
Delegate: t.Underlying.GetProtocol(trans),
LogPrefix: t.LogPrefix,
}
}
func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid)
log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
return err
}
func (tdp *TDebugProtocol) WriteMessageEnd() error {
err := tdp.Delegate.WriteMessageEnd()
log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteStructBegin(name string) error {
err := tdp.Delegate.WriteStructBegin(name)
log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
return err
}
func (tdp *TDebugProtocol) WriteStructEnd() error {
err := tdp.Delegate.WriteStructEnd()
log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
err := tdp.Delegate.WriteFieldBegin(name, typeId, id)
log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldEnd() error {
err := tdp.Delegate.WriteFieldEnd()
log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldStop() error {
err := tdp.Delegate.WriteFieldStop()
log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
err := tdp.Delegate.WriteMapBegin(keyType, valueType, size)
log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteMapEnd() error {
err := tdp.Delegate.WriteMapEnd()
log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteListBegin(elemType, size)
log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteListEnd() error {
err := tdp.Delegate.WriteListEnd()
log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteSetBegin(elemType, size)
log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteSetEnd() error {
err := tdp.Delegate.WriteSetEnd()
log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteBool(value bool) error {
err := tdp.Delegate.WriteBool(value)
log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteByte(value int8) error {
err := tdp.Delegate.WriteByte(value)
log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI16(value int16) error {
err := tdp.Delegate.WriteI16(value)
log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI32(value int32) error {
err := tdp.Delegate.WriteI32(value)
log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI64(value int64) error {
err := tdp.Delegate.WriteI64(value)
log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteDouble(value float64) error {
err := tdp.Delegate.WriteDouble(value)
log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteString(value string) error {
err := tdp.Delegate.WriteString(value)
log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
err := tdp.Delegate.WriteBinary(value)
log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
return
}
func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
err = tdp.Delegate.ReadMessageEnd()
log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
name, err = tdp.Delegate.ReadStructBegin()
log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
return
}
func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
err = tdp.Delegate.ReadStructEnd()
log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
return
}
func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
err = tdp.Delegate.ReadFieldEnd()
log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
return
}
func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
err = tdp.Delegate.ReadMapEnd()
log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadListBegin()
log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadListEnd() (err error) {
err = tdp.Delegate.ReadListEnd()
log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadSetBegin()
log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
err = tdp.Delegate.ReadSetEnd()
log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
value, err = tdp.Delegate.ReadBool()
log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
value, err = tdp.Delegate.ReadByte()
log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
value, err = tdp.Delegate.ReadI16()
log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
value, err = tdp.Delegate.ReadI32()
log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
value, err = tdp.Delegate.ReadI64()
log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
value, err = tdp.Delegate.ReadDouble()
log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadString() (value string, err error) {
value, err = tdp.Delegate.ReadString()
log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
value, err = tdp.Delegate.ReadBinary()
log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
err = tdp.Delegate.Skip(fieldType)
log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) Transport() TTransport {
return tdp.Delegate.Transport()
}

View file

@ -1,58 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
type TDeserializer struct {
Transport TTransport
Protocol TProtocol
}
func NewTDeserializer() *TDeserializer {
var transport TTransport
transport = NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
return &TDeserializer{
transport,
protocol}
}
func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
err = nil
if _, err = t.Transport.Write([]byte(s)); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
return
}
return
}
func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
err = nil
if _, err = t.Transport.Write(b); err != nil {
return
}
if err = msg.Read(t.Protocol); err != nil {
return
}
return
}

View file

@ -1,44 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
)
// Generic Thrift exception
type TException interface {
error
}
// Prepends additional information to an error without losing the Thrift exception interface
func PrependError(prepend string, err error) error {
if t, ok := err.(TTransportException); ok {
return NewTTransportException(t.TypeId(), prepend+t.Error())
}
if t, ok := err.(TProtocolException); ok {
return NewTProtocolExceptionWithType(t.TypeId(), errors.New(prepend+err.Error()))
}
if t, ok := err.(TApplicationException); ok {
return NewTApplicationException(t.TypeId(), prepend+t.Error())
}
return errors.New(prepend + err.Error())
}

View file

@ -1,69 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"errors"
"testing"
)
func TestPrependError(t *testing.T) {
err := NewTApplicationException(INTERNAL_ERROR, "original error")
err2, ok := PrependError("Prepend: ", err).(TApplicationException)
if !ok {
t.Fatal("Couldn't cast error TApplicationException")
}
if err2.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err2.TypeId() != INTERNAL_ERROR {
t.Fatal("Unexpected type error")
}
err3 := NewTProtocolExceptionWithType(INVALID_DATA, errors.New("original error"))
err4, ok := PrependError("Prepend: ", err3).(TProtocolException)
if !ok {
t.Fatal("Couldn't cast error TProtocolException")
}
if err4.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err4.TypeId() != INVALID_DATA {
t.Fatal("Unexpected type error")
}
err5 := NewTTransportException(TIMED_OUT, "original error")
err6, ok := PrependError("Prepend: ", err5).(TTransportException)
if !ok {
t.Fatal("Couldn't cast error TTransportException")
}
if err6.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
if err6.TypeId() != TIMED_OUT {
t.Fatal("Unexpected type error")
}
err7 := errors.New("original error")
err8 := PrependError("Prepend: ", err7)
if err8.Error() != "Prepend: original error" {
t.Fatal("Unexpected error string")
}
}

View file

@ -1,79 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Helper class that encapsulates field metadata.
type field struct {
name string
typeId TType
id int
}
func newField(n string, t TType, i int) *field {
return &field{name: n, typeId: t, id: i}
}
func (p *field) Name() string {
if p == nil {
return ""
}
return p.name
}
func (p *field) TypeId() TType {
if p == nil {
return TType(VOID)
}
return p.typeId
}
func (p *field) Id() int {
if p == nil {
return -1
}
return p.id
}
func (p *field) String() string {
if p == nil {
return "<nil>"
}
return "<TField name:'" + p.name + "' type:" + string(p.typeId) + " field-id:" + string(p.id) + ">"
}
var ANONYMOUS_FIELD *field
type fieldSlice []field
func (p fieldSlice) Len() int {
return len(p)
}
func (p fieldSlice) Less(i, j int) bool {
return p[i].Id() < p[j].Id()
}
func (p fieldSlice) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
func init() {
ANONYMOUS_FIELD = newField("", STOP, 0)
}

View file

@ -1,173 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
)
const DEFAULT_MAX_LENGTH = 16384000
type TFramedTransport struct {
transport TTransport
buf bytes.Buffer
reader *bufio.Reader
frameSize uint32 //Current remaining size of the frame. if ==0 read next frame header
buffer [4]byte
maxLength uint32
}
type tFramedTransportFactory struct {
factory TTransportFactory
maxLength uint32
}
func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH}
}
func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory {
return &tFramedTransportFactory{factory: factory, maxLength: maxLength}
}
func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) {
tt, err := p.factory.GetTransport(base)
if err != nil {
return nil, err
}
return NewTFramedTransportMaxLength(tt, p.maxLength), nil
}
func NewTFramedTransport(transport TTransport) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH}
}
func NewTFramedTransportMaxLength(transport TTransport, maxLength uint32) *TFramedTransport {
return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength}
}
func (p *TFramedTransport) Open() error {
return p.transport.Open()
}
func (p *TFramedTransport) IsOpen() bool {
return p.transport.IsOpen()
}
func (p *TFramedTransport) Close() error {
return p.transport.Close()
}
func (p *TFramedTransport) Read(buf []byte) (l int, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < uint32(len(buf)) {
frameSize := p.frameSize
tmp := make([]byte, p.frameSize)
l, err = p.Read(tmp)
copy(buf, tmp)
if err == nil {
err = NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", frameSize, len(buf)))
return
}
}
got, err := p.reader.Read(buf)
p.frameSize = p.frameSize - uint32(got)
//sanity check
if p.frameSize < 0 {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size")
}
return got, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) ReadByte() (c byte, err error) {
if p.frameSize == 0 {
p.frameSize, err = p.readFrameHeader()
if err != nil {
return
}
}
if p.frameSize < 1 {
return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d bytes", p.frameSize, 1))
}
c, err = p.reader.ReadByte()
if err == nil {
p.frameSize--
}
return
}
func (p *TFramedTransport) Write(buf []byte) (int, error) {
n, err := p.buf.Write(buf)
return n, NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) WriteByte(c byte) error {
return p.buf.WriteByte(c)
}
func (p *TFramedTransport) WriteString(s string) (n int, err error) {
return p.buf.WriteString(s)
}
func (p *TFramedTransport) Flush(ctx context.Context) error {
size := p.buf.Len()
buf := p.buffer[:4]
binary.BigEndian.PutUint32(buf, uint32(size))
_, err := p.transport.Write(buf)
if err != nil {
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
if size > 0 {
if n, err := p.buf.WriteTo(p.transport); err != nil {
print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n")
p.buf.Truncate(0)
return NewTTransportExceptionFromError(err)
}
}
err = p.transport.Flush(ctx)
return NewTTransportExceptionFromError(err)
}
func (p *TFramedTransport) readFrameHeader() (uint32, error) {
buf := p.buffer[:4]
if _, err := io.ReadFull(p.reader, buf); err != nil {
return 0, err
}
size := binary.BigEndian.Uint32(buf)
if size < 0 || size > p.maxLength {
return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
}
return size, nil
}
func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
return uint64(p.frameSize)
}

View file

@ -1,29 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestFramedTransport(t *testing.T) {
trans := NewTFramedTransport(NewTMemoryBuffer())
TransportTest(t, trans, trans)
}

View file

@ -1,242 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
)
// Default to using the shared http client. Library users are
// free to change this global client or specify one through
// THttpClientOptions.
var DefaultHttpClient *http.Client = http.DefaultClient
type THttpClient struct {
client *http.Client
response *http.Response
url *url.URL
requestBuffer *bytes.Buffer
header http.Header
nsecConnectTimeout int64
nsecReadTimeout int64
}
type THttpClientTransportFactory struct {
options THttpClientOptions
url string
}
func (p *THttpClientTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*THttpClient)
if ok && t.url != nil {
return NewTHttpClientWithOptions(t.url.String(), p.options)
}
}
return NewTHttpClientWithOptions(p.url, p.options)
}
type THttpClientOptions struct {
// If nil, DefaultHttpClient is used
Client *http.Client
}
func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return &THttpClientTransportFactory{url: url, options: options}
}
func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
}
buf := make([]byte, 0, 1024)
client := options.Client
if client == nil {
client = DefaultHttpClient
}
httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}}
return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil
}
func NewTHttpClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}
// Set the HTTP Header for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// httpTrans.SetHeader("User-Agent","Thrift Client 1.0")
func (p *THttpClient) SetHeader(key string, value string) {
p.header.Add(key, value)
}
// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// hdrValue := httpTrans.GetHeader("User-Agent")
func (p *THttpClient) GetHeader(key string) string {
return p.header.Get(key)
}
// Deletes the HTTP Header given a Header Key for this specific Thrift Transport
// It is important that you first assert the TTransport as a THttpClient type
// like so:
//
// httpTrans := trans.(THttpClient)
// httpTrans.DelHeader("User-Agent")
func (p *THttpClient) DelHeader(key string) {
p.header.Del(key)
}
func (p *THttpClient) Open() error {
// do nothing
return nil
}
func (p *THttpClient) IsOpen() bool {
return p.response != nil || p.requestBuffer != nil
}
func (p *THttpClient) closeResponse() error {
var err error
if p.response != nil && p.response.Body != nil {
// The docs specify that if keepalive is enabled and the response body is not
// read to completion the connection will never be returned to the pool and
// reused. Errors are being ignored here because if the connection is invalid
// and this fails for some reason, the Close() method will do any remaining
// cleanup.
io.Copy(ioutil.Discard, p.response.Body)
err = p.response.Body.Close()
}
p.response = nil
return err
}
func (p *THttpClient) Close() error {
if p.requestBuffer != nil {
p.requestBuffer.Reset()
p.requestBuffer = nil
}
return p.closeResponse()
}
func (p *THttpClient) Read(buf []byte) (int, error) {
if p.response == nil {
return 0, NewTTransportException(NOT_OPEN, "Response buffer is empty, no request.")
}
n, err := p.response.Body.Read(buf)
if n > 0 && (err == nil || err == io.EOF) {
return n, nil
}
return n, NewTTransportExceptionFromError(err)
}
func (p *THttpClient) ReadByte() (c byte, err error) {
return readByte(p.response.Body)
}
func (p *THttpClient) Write(buf []byte) (int, error) {
n, err := p.requestBuffer.Write(buf)
return n, err
}
func (p *THttpClient) WriteByte(c byte) error {
return p.requestBuffer.WriteByte(c)
}
func (p *THttpClient) WriteString(s string) (n int, err error) {
return p.requestBuffer.WriteString(s)
}
func (p *THttpClient) Flush(ctx context.Context) error {
// Close any previous response body to avoid leaking connections.
p.closeResponse()
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
if err != nil {
return NewTTransportExceptionFromError(err)
}
req.Header = p.header
if ctx != nil {
req = req.WithContext(ctx)
}
response, err := p.client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)
}
if response.StatusCode != http.StatusOK {
// Close the response to avoid leaking file descriptors. closeResponse does
// more than just call Close(), so temporarily assign it and reuse the logic.
p.response = response
p.closeResponse()
// TODO(pomack) log bad response
return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "HTTP Response code: "+strconv.Itoa(response.StatusCode))
}
p.response = response
return nil
}
func (p *THttpClient) RemainingBytes() (num_bytes uint64) {
len := p.response.ContentLength
if len >= 0 {
return uint64(len)
}
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
}
// Deprecated: Use NewTHttpClientTransportFactory instead.
func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
// Deprecated: Use NewTHttpClientTransportFactoryWithOptions instead.
func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
return NewTHttpClientTransportFactoryWithOptions(url, options)
}
// Deprecated: Use NewTHttpClientWithOptions instead.
func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, options)
}
// Deprecated: Use NewTHttpClient instead.
func NewTHttpPostClient(urlstr string) (TTransport, error) {
return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
}

View file

@ -1,106 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"net/http"
"testing"
)
func TestHttpClient(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportTest(t, trans, trans)
}
func TestHttpClientHeaders(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
}
func TestHttpCustomClient(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
httpTransport := &customHttpTransport{}
trans, err := NewTHttpPostClientWithOptions("http://"+addr.String(), THttpClientOptions{
Client: &http.Client{
Transport: httpTransport,
},
})
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
if !httpTransport.hit {
t.Fatalf("Custom client was not used")
}
}
func TestHttpCustomClientPackageScope(t *testing.T) {
l, addr := HttpClientSetupForTest(t)
if l != nil {
defer l.Close()
}
httpTransport := &customHttpTransport{}
DefaultHttpClient = &http.Client{
Transport: httpTransport,
}
trans, err := NewTHttpPostClient("http://" + addr.String())
if err != nil {
l.Close()
t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
}
TransportHeaderTest(t, trans, trans)
if !httpTransport.hit {
t.Fatalf("Custom client was not used")
}
}
type customHttpTransport struct {
hit bool
}
func (c *customHttpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
c.hit = true
return http.DefaultTransport.RoundTrip(req)
}

View file

@ -1,63 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
// NewThriftHandlerFunc is a function that create a ready to use Apache Thrift Handler function
func NewThriftHandlerFunc(processor TProcessor,
inPfactory, outPfactory TProtocolFactory) func(w http.ResponseWriter, r *http.Request) {
return gz(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/x-thrift")
transport := NewStreamTransport(r.Body, w)
processor.Process(r.Context(), inPfactory.GetProtocol(transport), outPfactory.GetProtocol(transport))
})
}
// gz transparently compresses the HTTP response if the client supports it.
func gz(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
handler(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzip.NewWriter(w)
defer gz.Close()
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
handler(gzw, r)
}
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}

View file

@ -1,214 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bufio"
"context"
"io"
)
// StreamTransport is a Transport made of an io.Reader and/or an io.Writer
type StreamTransport struct {
io.Reader
io.Writer
isReadWriter bool
closed bool
}
type StreamTransportFactory struct {
Reader io.Reader
Writer io.Writer
isReadWriter bool
}
func (p *StreamTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*StreamTransport)
if ok {
if t.isReadWriter {
return NewStreamTransportRW(t.Reader.(io.ReadWriter)), nil
}
if t.Reader != nil && t.Writer != nil {
return NewStreamTransport(t.Reader, t.Writer), nil
}
if t.Reader != nil && t.Writer == nil {
return NewStreamTransportR(t.Reader), nil
}
if t.Reader == nil && t.Writer != nil {
return NewStreamTransportW(t.Writer), nil
}
return &StreamTransport{}, nil
}
}
if p.isReadWriter {
return NewStreamTransportRW(p.Reader.(io.ReadWriter)), nil
}
if p.Reader != nil && p.Writer != nil {
return NewStreamTransport(p.Reader, p.Writer), nil
}
if p.Reader != nil && p.Writer == nil {
return NewStreamTransportR(p.Reader), nil
}
if p.Reader == nil && p.Writer != nil {
return NewStreamTransportW(p.Writer), nil
}
return &StreamTransport{}, nil
}
func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory {
return &StreamTransportFactory{Reader: reader, Writer: writer, isReadWriter: isReadWriter}
}
func NewStreamTransport(r io.Reader, w io.Writer) *StreamTransport {
return &StreamTransport{Reader: bufio.NewReader(r), Writer: bufio.NewWriter(w)}
}
func NewStreamTransportR(r io.Reader) *StreamTransport {
return &StreamTransport{Reader: bufio.NewReader(r)}
}
func NewStreamTransportW(w io.Writer) *StreamTransport {
return &StreamTransport{Writer: bufio.NewWriter(w)}
}
func NewStreamTransportRW(rw io.ReadWriter) *StreamTransport {
bufrw := bufio.NewReadWriter(bufio.NewReader(rw), bufio.NewWriter(rw))
return &StreamTransport{Reader: bufrw, Writer: bufrw, isReadWriter: true}
}
func (p *StreamTransport) IsOpen() bool {
return !p.closed
}
// implicitly opened on creation, can't be reopened once closed
func (p *StreamTransport) Open() error {
if !p.closed {
return NewTTransportException(ALREADY_OPEN, "StreamTransport already open.")
} else {
return NewTTransportException(NOT_OPEN, "cannot reopen StreamTransport.")
}
}
// Closes both the input and output streams.
func (p *StreamTransport) Close() error {
if p.closed {
return NewTTransportException(NOT_OPEN, "StreamTransport already closed.")
}
p.closed = true
closedReader := false
if p.Reader != nil {
c, ok := p.Reader.(io.Closer)
if ok {
e := c.Close()
closedReader = true
if e != nil {
return e
}
}
p.Reader = nil
}
if p.Writer != nil && (!closedReader || !p.isReadWriter) {
c, ok := p.Writer.(io.Closer)
if ok {
e := c.Close()
if e != nil {
return e
}
}
p.Writer = nil
}
return nil
}
// Flushes the underlying output stream if not null.
func (p *StreamTransport) Flush(ctx context.Context) error {
if p.Writer == nil {
return NewTTransportException(NOT_OPEN, "Cannot flush null outputStream")
}
f, ok := p.Writer.(Flusher)
if ok {
err := f.Flush()
if err != nil {
return NewTTransportExceptionFromError(err)
}
}
return nil
}
func (p *StreamTransport) Read(c []byte) (n int, err error) {
n, err = p.Reader.Read(c)
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) ReadByte() (c byte, err error) {
f, ok := p.Reader.(io.ByteReader)
if ok {
c, err = f.ReadByte()
} else {
c, err = readByte(p.Reader)
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) Write(c []byte) (n int, err error) {
n, err = p.Writer.Write(c)
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) WriteByte(c byte) (err error) {
f, ok := p.Writer.(io.ByteWriter)
if ok {
err = f.WriteByte(c)
} else {
err = writeByte(p.Writer, c)
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) WriteString(s string) (n int, err error) {
f, ok := p.Writer.(stringWriter)
if ok {
n, err = f.WriteString(s)
} else {
n, err = p.Writer.Write([]byte(s))
}
if err != nil {
err = NewTTransportExceptionFromError(err)
}
return
}
func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
return maxSize // the thruth is, we just don't know unless framed is used
}

View file

@ -1,52 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"testing"
)
func TestStreamTransport(t *testing.T) {
trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024)))
TransportTest(t, trans, trans)
}
func TestStreamTransportOpenClose(t *testing.T) {
trans := NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 1024)))
if !trans.IsOpen() {
t.Fatal("StreamTransport should be already open")
}
if trans.Open() == nil {
t.Fatal("StreamTransport should return error when open twice")
}
if trans.Close() != nil {
t.Fatal("StreamTransport should not return error when closing open transport")
}
if trans.IsOpen() {
t.Fatal("StreamTransport should not be open after close")
}
if trans.Close() == nil {
t.Fatal("StreamTransport should return error when closing a non open transport")
}
if trans.Open() == nil {
t.Fatal("StreamTransport should not be able to reopen")
}
}

View file

@ -1,584 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/base64"
"fmt"
)
const (
THRIFT_JSON_PROTOCOL_VERSION = 1
)
// for references to _ParseContext see tsimplejson_protocol.go
// JSON protocol implementation for thrift.
//
// This protocol produces/consumes a simple output format
// suitable for parsing by scripting languages. It should not be
// confused with the full-featured TJSONProtocol.
//
type TJSONProtocol struct {
*TSimpleJSONProtocol
}
// Constructor
func NewTJSONProtocol(t TTransport) *TJSONProtocol {
v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
return v
}
// Factory
type TJSONProtocolFactory struct{}
func (p *TJSONProtocolFactory) GetProtocol(trans TTransport) TProtocol {
return NewTJSONProtocol(trans)
}
func NewTJSONProtocolFactory() *TJSONProtocolFactory {
return &TJSONProtocolFactory{}
}
func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error {
p.resetContextStack() // THRIFT-3735
if e := p.OutputListBegin(); e != nil {
return e
}
if e := p.WriteI32(THRIFT_JSON_PROTOCOL_VERSION); e != nil {
return e
}
if e := p.WriteString(name); e != nil {
return e
}
if e := p.WriteByte(int8(typeId)); e != nil {
return e
}
if e := p.WriteI32(seqId); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteMessageEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteStructBegin(name string) error {
if e := p.OutputObjectBegin(); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteStructEnd() error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
if e := p.WriteI16(id); e != nil {
return e
}
if e := p.OutputObjectBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(typeId)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) WriteFieldEnd() error {
return p.OutputObjectEnd()
}
func (p *TJSONProtocol) WriteFieldStop() error { return nil }
func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(keyType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
s, e1 = p.TypeIdToString(valueType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
return e
}
return p.OutputObjectBegin()
}
func (p *TJSONProtocol) WriteMapEnd() error {
if e := p.OutputObjectEnd(); e != nil {
return e
}
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteListBegin(elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteListEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteSetBegin(elemType TType, size int) error {
return p.OutputElemListBegin(elemType, size)
}
func (p *TJSONProtocol) WriteSetEnd() error {
return p.OutputListEnd()
}
func (p *TJSONProtocol) WriteBool(b bool) error {
if b {
return p.WriteI32(1)
}
return p.WriteI32(0)
}
func (p *TJSONProtocol) WriteByte(b int8) error {
return p.WriteI32(int32(b))
}
func (p *TJSONProtocol) WriteI16(v int16) error {
return p.WriteI32(int32(v))
}
func (p *TJSONProtocol) WriteI32(v int32) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteI64(v int64) error {
return p.OutputI64(int64(v))
}
func (p *TJSONProtocol) WriteDouble(v float64) error {
return p.OutputF64(v)
}
func (p *TJSONProtocol) WriteString(v string) error {
return p.OutputString(v)
}
func (p *TJSONProtocol) WriteBinary(v []byte) error {
// JSON library only takes in a string,
// not an arbitrary byte array, to ensure bytes are transmitted
// efficiently we must convert this into a valid JSON string
// therefore we use base64 encoding to avoid excessive escaping/quoting
if e := p.OutputPreValue(); e != nil {
return e
}
if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
return NewTProtocolException(e)
}
writer := base64.NewEncoder(base64.StdEncoding, p.writer)
if _, e := writer.Write(v); e != nil {
p.writer.Reset(p.trans) // THRIFT-3735
return NewTProtocolException(e)
}
if e := writer.Close(); e != nil {
return NewTProtocolException(e)
}
if _, e := p.write(JSON_QUOTE_BYTES); e != nil {
return NewTProtocolException(e)
}
return p.OutputPostValue()
}
// Reading methods.
func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) {
p.resetContextStack() // THRIFT-3735
if isNull, err := p.ParseListBegin(); isNull || err != nil {
return name, typeId, seqId, err
}
version, err := p.ReadI32()
if err != nil {
return name, typeId, seqId, err
}
if version != THRIFT_JSON_PROTOCOL_VERSION {
e := fmt.Errorf("Unknown Protocol version %d, expected version %d", version, THRIFT_JSON_PROTOCOL_VERSION)
return name, typeId, seqId, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
if name, err = p.ReadString(); err != nil {
return name, typeId, seqId, err
}
bTypeId, err := p.ReadByte()
typeId = TMessageType(bTypeId)
if err != nil {
return name, typeId, seqId, err
}
if seqId, err = p.ReadI32(); err != nil {
return name, typeId, seqId, err
}
return name, typeId, seqId, nil
}
func (p *TJSONProtocol) ReadMessageEnd() error {
err := p.ParseListEnd()
return err
}
func (p *TJSONProtocol) ReadStructBegin() (name string, err error) {
_, err = p.ParseObjectStart()
return "", err
}
func (p *TJSONProtocol) ReadStructEnd() error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
b, _ := p.reader.Peek(1)
if len(b) < 1 || b[0] == JSON_RBRACE[0] || b[0] == JSON_RBRACKET[0] {
return "", STOP, -1, nil
}
fieldId, err := p.ReadI16()
if err != nil {
return "", STOP, fieldId, err
}
if _, err = p.ParseObjectStart(); err != nil {
return "", STOP, fieldId, err
}
sType, err := p.ReadString()
if err != nil {
return "", STOP, fieldId, err
}
fType, err := p.StringToTypeId(sType)
return "", fType, fieldId, err
}
func (p *TJSONProtocol) ReadFieldEnd() error {
return p.ParseObjectEnd()
}
func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, VOID, 0, e
}
// read keyType
sKeyType, e := p.ReadString()
if e != nil {
return keyType, valueType, size, e
}
keyType, e = p.StringToTypeId(sKeyType)
if e != nil {
return keyType, valueType, size, e
}
// read valueType
sValueType, e := p.ReadString()
if e != nil {
return keyType, valueType, size, e
}
valueType, e = p.StringToTypeId(sValueType)
if e != nil {
return keyType, valueType, size, e
}
// read size
iSize, e := p.ReadI64()
if e != nil {
return keyType, valueType, size, e
}
size = int(iSize)
_, e = p.ParseObjectStart()
return keyType, valueType, size, e
}
func (p *TJSONProtocol) ReadMapEnd() error {
e := p.ParseObjectEnd()
if e != nil {
return e
}
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadListBegin() (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadListEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadSetBegin() (elemType TType, size int, e error) {
return p.ParseElemListBegin()
}
func (p *TJSONProtocol) ReadSetEnd() error {
return p.ParseListEnd()
}
func (p *TJSONProtocol) ReadBool() (bool, error) {
value, err := p.ReadI32()
return (value != 0), err
}
func (p *TJSONProtocol) ReadByte() (int8, error) {
v, err := p.ReadI64()
return int8(v), err
}
func (p *TJSONProtocol) ReadI16() (int16, error) {
v, err := p.ReadI64()
return int16(v), err
}
func (p *TJSONProtocol) ReadI32() (int32, error) {
v, err := p.ReadI64()
return int32(v), err
}
func (p *TJSONProtocol) ReadI64() (int64, error) {
v, _, err := p.ParseI64()
return v, err
}
func (p *TJSONProtocol) ReadDouble() (float64, error) {
v, _, err := p.ParseF64()
return v, err
}
func (p *TJSONProtocol) ReadString() (string, error) {
var v string
if err := p.ParsePreValue(); err != nil {
return v, err
}
f, _ := p.reader.Peek(1)
if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseStringBody()
v = value
if err != nil {
return v, err
}
} else if len(f) > 0 && f[0] == JSON_NULL[0] {
b := make([]byte, len(JSON_NULL))
_, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
if string(b) != string(JSON_NULL) {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
} else {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) ReadBinary() ([]byte, error) {
var v []byte
if err := p.ParsePreValue(); err != nil {
return nil, err
}
f, _ := p.reader.Peek(1)
if len(f) > 0 && f[0] == JSON_QUOTE {
p.reader.ReadByte()
value, err := p.ParseBase64EncodedBody()
v = value
if err != nil {
return v, err
}
} else if len(f) > 0 && f[0] == JSON_NULL[0] {
b := make([]byte, len(JSON_NULL))
_, err := p.reader.Read(b)
if err != nil {
return v, NewTProtocolException(err)
}
if string(b) != string(JSON_NULL) {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(b))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
} else {
e := fmt.Errorf("Expected a JSON string, found unquoted data started with %s", string(f))
return v, NewTProtocolExceptionWithType(INVALID_DATA, e)
}
return v, p.ParsePostValue()
}
func (p *TJSONProtocol) Flush(ctx context.Context) (err error) {
err = p.writer.Flush()
if err == nil {
err = p.trans.Flush(ctx)
}
return NewTProtocolException(err)
}
func (p *TJSONProtocol) Skip(fieldType TType) (err error) {
return SkipDefaultDepth(p, fieldType)
}
func (p *TJSONProtocol) Transport() TTransport {
return p.trans
}
func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(elemType)
if e1 != nil {
return e1
}
if e := p.WriteString(s); e != nil {
return e
}
if e := p.WriteI64(int64(size)); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
if err != nil {
return VOID, size, err
}
elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
size = int(nSize)
return elemType, size, err2
}
func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) {
if isNull, e := p.ParseListBegin(); isNull || e != nil {
return VOID, 0, e
}
sElemType, err := p.ReadString()
if err != nil {
return VOID, size, err
}
elemType, err = p.StringToTypeId(sElemType)
if err != nil {
return elemType, size, err
}
nSize, err2 := p.ReadI64()
size = int(nSize)
return elemType, size, err2
}
func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error {
if e := p.OutputListBegin(); e != nil {
return e
}
s, e1 := p.TypeIdToString(elemType)
if e1 != nil {
return e1
}
if e := p.OutputString(s); e != nil {
return e
}
if e := p.OutputI64(int64(size)); e != nil {
return e
}
return nil
}
func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) {
switch byte(fieldType) {
case BOOL:
return "tf", nil
case BYTE:
return "i8", nil
case I16:
return "i16", nil
case I32:
return "i32", nil
case I64:
return "i64", nil
case DOUBLE:
return "dbl", nil
case STRING:
return "str", nil
case STRUCT:
return "rec", nil
case MAP:
return "map", nil
case SET:
return "set", nil
case LIST:
return "lst", nil
}
e := fmt.Errorf("Unknown fieldType: %d", int(fieldType))
return "", NewTProtocolExceptionWithType(INVALID_DATA, e)
}
func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
switch fieldType {
case "tf":
return TType(BOOL), nil
case "i8":
return TType(BYTE), nil
case "i16":
return TType(I16), nil
case "i32":
return TType(I32), nil
case "i64":
return TType(I64), nil
case "dbl":
return TType(DOUBLE), nil
case "str":
return TType(STRING), nil
case "rec":
return TType(STRUCT), nil
case "map":
return TType(MAP), nil
case "set":
return TType(SET), nil
case "lst":
return TType(LIST), nil
}
e := fmt.Errorf("Unknown type identifier: %s", fieldType)
return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
}

View file

@ -1,650 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"math"
"strconv"
"testing"
)
func TestWriteJSONProtocolBool(t *testing.T) {
thetype := "boolean"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range BOOL_VALUES {
if e := p.WriteBool(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
expected := ""
if value {
expected = "1"
} else {
expected = "0"
}
if s != expected {
t.Fatalf("Bad value for %s %v: %s expected", thetype, value, s)
}
v := -1
if err := json.Unmarshal([]byte(s), &v); err != nil || (v != 0) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolBool(t *testing.T) {
thetype := "boolean"
for _, value := range BOOL_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
if value {
trans.Write([]byte{'1'}) // not JSON_TRUE
} else {
trans.Write([]byte{'0'}) // not JSON_FALSE
}
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBool()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
vv := -1
if err := json.Unmarshal([]byte(s), &vv); err != nil || (vv != 0) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, vv)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolByte(t *testing.T) {
thetype := "byte"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range BYTE_VALUES {
if e := p.WriteByte(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int8(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolByte(t *testing.T) {
thetype := "byte"
for _, value := range BYTE_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadByte()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolI16(t *testing.T) {
thetype := "int16"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT16_VALUES {
if e := p.WriteI16(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int16(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolI16(t *testing.T) {
thetype := "int16"
for _, value := range INT16_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI16()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolI32(t *testing.T) {
thetype := "int32"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT32_VALUES {
if e := p.WriteI32(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int32(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolI32(t *testing.T) {
thetype := "int32"
for _, value := range INT32_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.Itoa(int(value)))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI32()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolI64(t *testing.T) {
thetype := "int64"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range INT64_VALUES {
if e := p.WriteI64(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := int64(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolI64(t *testing.T) {
thetype := "int64"
for _, value := range INT64_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(strconv.FormatInt(value, 10))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadI64()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolDouble(t *testing.T) {
thetype := "double"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range DOUBLE_VALUES {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if math.IsInf(value, 1) {
if s != jsonQuote(JSON_INFINITY) {
t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_INFINITY))
}
} else if math.IsInf(value, -1) {
if s != jsonQuote(JSON_NEGATIVE_INFINITY) {
t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NEGATIVE_INFINITY))
}
} else if math.IsNaN(value) {
if s != jsonQuote(JSON_NAN) {
t.Fatalf("Bad value for %s %v, wrote: %v, expected: %v", thetype, value, s, jsonQuote(JSON_NAN))
}
} else {
if s != fmt.Sprint(value) {
t.Fatalf("Bad value for %s %v: %s", thetype, value, s)
}
v := float64(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolDouble(t *testing.T) {
thetype := "double"
for _, value := range DOUBLE_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
n := NewNumericFromDouble(value)
trans.WriteString(n.String())
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadDouble()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if math.IsInf(value, 1) {
if !math.IsInf(v, 1) {
t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v)
}
} else if math.IsInf(value, -1) {
if !math.IsInf(v, -1) {
t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v)
}
} else if math.IsNaN(value) {
if !math.IsNaN(v) {
t.Fatalf("Bad value for %s %v, wrote: %v, received: %v", thetype, value, s, v)
}
} else {
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolString(t *testing.T) {
thetype := "string"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
for _, value := range STRING_VALUES {
if e := p.WriteString(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
if s[0] != '"' || s[len(s)-1] != '"' {
t.Fatalf("Bad value for %s '%v', wrote '%v', expected: %v", thetype, value, s, fmt.Sprint("\"", value, "\""))
}
v := new(string)
if err := json.Unmarshal([]byte(s), v); err != nil || *v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v)
}
trans.Reset()
}
trans.Close()
}
func TestReadJSONProtocolString(t *testing.T) {
thetype := "string"
for _, value := range STRING_VALUES {
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(jsonQuote(value))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadString()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if v != value {
t.Fatalf("Bad value for %s value %v, wrote: %v, received: %v", thetype, value, s, v)
}
v1 := new(string)
if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1)
}
trans.Reset()
trans.Close()
}
}
func TestWriteJSONProtocolBinary(t *testing.T) {
thetype := "binary"
value := protocol_bdata
b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata)))
base64.StdEncoding.Encode(b64value, value)
b64String := string(b64value)
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
if e := p.WriteBinary(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s value %v due to error flushing: %s", thetype, value, e.Error())
}
s := trans.String()
expectedString := fmt.Sprint("\"", b64String, "\"")
if s != expectedString {
t.Fatalf("Bad value for %s %v\n wrote: \"%v\"\nexpected: \"%v\"", thetype, value, s, expectedString)
}
v1, err := p.ReadBinary()
if err != nil {
t.Fatalf("Unable to read binary: %s", err.Error())
}
if len(v1) != len(value) {
t.Fatalf("Invalid value for binary\nexpected: \"%v\"\n read: \"%v\"", value, v1)
}
for k, v := range value {
if v1[k] != v {
t.Fatalf("Invalid value for binary at %v\nexpected: \"%v\"\n read: \"%v\"", k, v, v1[k])
}
}
trans.Close()
}
func TestReadJSONProtocolBinary(t *testing.T) {
thetype := "binary"
value := protocol_bdata
b64value := make([]byte, base64.StdEncoding.EncodedLen(len(protocol_bdata)))
base64.StdEncoding.Encode(b64value, value)
b64String := string(b64value)
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
trans.WriteString(jsonQuote(b64String))
trans.Flush(context.Background())
s := trans.String()
v, e := p.ReadBinary()
if e != nil {
t.Fatalf("Unable to read %s value %v due to error: %s", thetype, value, e.Error())
}
if len(v) != len(value) {
t.Fatalf("Bad value for %s value length %v, wrote: %v, received length: %v", thetype, len(value), s, len(v))
}
for i := 0; i < len(v); i++ {
if v[i] != value[i] {
t.Fatalf("Bad value for %s at index %d value %v, wrote: %v, received: %v", thetype, i, value[i], s, v[i])
}
}
v1 := new(string)
if err := json.Unmarshal([]byte(s), v1); err != nil || *v1 != b64String {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, *v1)
}
trans.Reset()
trans.Close()
}
func TestWriteJSONProtocolList(t *testing.T) {
thetype := "list"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
p.WriteListBegin(TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
p.WriteListEnd()
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
str1 := new([]interface{})
err := json.Unmarshal([]byte(str), str1)
if err != nil {
t.Fatalf("Unable to decode %s, wrote: %s", thetype, str)
}
l := *str1
if len(l) < 2 {
t.Fatalf("List must be at least of length two to include metadata")
}
if l[0] != "dbl" {
t.Fatal("Invalid type for list, expected: ", STRING, ", but was: ", l[0])
}
if int(l[1].(float64)) != len(DOUBLE_VALUES) {
t.Fatal("Invalid length for list, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1])
}
for k, value := range DOUBLE_VALUES {
s := l[k+2]
if math.IsInf(value, 1) {
if s.(string) != JSON_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str)
}
} else if math.IsInf(value, 0) {
if s.(string) != JSON_NEGATIVE_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str)
}
} else if math.IsNaN(value) {
if s.(string) != JSON_NAN {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str)
}
} else {
if s.(float64) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s)
}
}
trans.Reset()
}
trans.Close()
}
func TestWriteJSONProtocolSet(t *testing.T) {
thetype := "set"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
p.WriteSetBegin(TType(DOUBLE), len(DOUBLE_VALUES))
for _, value := range DOUBLE_VALUES {
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value %v due to error: %s", thetype, value, e.Error())
}
}
p.WriteSetEnd()
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
str1 := new([]interface{})
err := json.Unmarshal([]byte(str), str1)
if err != nil {
t.Fatalf("Unable to decode %s, wrote: %s", thetype, str)
}
l := *str1
if len(l) < 2 {
t.Fatalf("Set must be at least of length two to include metadata")
}
if l[0] != "dbl" {
t.Fatal("Invalid type for set, expected: ", DOUBLE, ", but was: ", l[0])
}
if int(l[1].(float64)) != len(DOUBLE_VALUES) {
t.Fatal("Invalid length for set, expected: ", len(DOUBLE_VALUES), ", but was: ", l[1])
}
for k, value := range DOUBLE_VALUES {
s := l[k+2]
if math.IsInf(value, 1) {
if s.(string) != JSON_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_INFINITY), str)
}
} else if math.IsInf(value, 0) {
if s.(string) != JSON_NEGATIVE_INFINITY {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY), str)
}
} else if math.IsNaN(value) {
if s.(string) != JSON_NAN {
t.Fatalf("Bad value for %s at index %v %v, wrote: %q, expected: %q, originally wrote: %q", thetype, k, value, s, jsonQuote(JSON_NAN), str)
}
} else {
if s.(float64) != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s'", thetype, value, s)
}
}
trans.Reset()
}
trans.Close()
}
func TestWriteJSONProtocolMap(t *testing.T) {
thetype := "map"
trans := NewTMemoryBuffer()
p := NewTJSONProtocol(trans)
p.WriteMapBegin(TType(I32), TType(DOUBLE), len(DOUBLE_VALUES))
for k, value := range DOUBLE_VALUES {
if e := p.WriteI32(int32(k)); e != nil {
t.Fatalf("Unable to write %s key int32 value %v due to error: %s", thetype, k, e.Error())
}
if e := p.WriteDouble(value); e != nil {
t.Fatalf("Unable to write %s value float64 value %v due to error: %s", thetype, value, e.Error())
}
}
p.WriteMapEnd()
if e := p.Flush(context.Background()); e != nil {
t.Fatalf("Unable to write %s due to error flushing: %s", thetype, e.Error())
}
str := trans.String()
if str[0] != '[' || str[len(str)-1] != ']' {
t.Fatalf("Bad value for %s, wrote: %v, in go: %v", thetype, str, DOUBLE_VALUES)
}
expectedKeyType, expectedValueType, expectedSize, err := p.ReadMapBegin()
if err != nil {
t.Fatalf("Error while reading map begin: %s", err.Error())
}
if expectedKeyType != I32 {
t.Fatal("Expected map key type ", I32, ", but was ", expectedKeyType)
}
if expectedValueType != DOUBLE {
t.Fatal("Expected map value type ", DOUBLE, ", but was ", expectedValueType)
}
if expectedSize != len(DOUBLE_VALUES) {
t.Fatal("Expected map size of ", len(DOUBLE_VALUES), ", but was ", expectedSize)
}
for k, value := range DOUBLE_VALUES {
ik, err := p.ReadI32()
if err != nil {
t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, ik, string(k), err.Error())
}
if int(ik) != k {
t.Fatalf("Bad key for %s index %v, wrote: %v, expected: %v", thetype, k, ik, k)
}
dv, err := p.ReadDouble()
if err != nil {
t.Fatalf("Bad value for %s index %v, wrote: %v, expected: %v, error: %s", thetype, k, dv, value, err.Error())
}
s := strconv.FormatFloat(dv, 'g', 10, 64)
if math.IsInf(value, 1) {
if !math.IsInf(dv, 1) {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_INFINITY))
}
} else if math.IsInf(value, 0) {
if !math.IsInf(dv, 0) {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NEGATIVE_INFINITY))
}
} else if math.IsNaN(value) {
if !math.IsNaN(dv) {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected: %v", thetype, k, value, s, jsonQuote(JSON_NAN))
}
} else {
expected := strconv.FormatFloat(value, 'g', 10, 64)
if s != expected {
t.Fatalf("Bad value for %s at index %v %v, wrote: %v, expected %v", thetype, k, value, s, expected)
}
v := float64(0)
if err := json.Unmarshal([]byte(s), &v); err != nil || v != value {
t.Fatalf("Bad json-decoded value for %s %v, wrote: '%s', expected: '%v'", thetype, value, s, v)
}
}
}
err = p.ReadMapEnd()
if err != nil {
t.Fatalf("Error while reading map end: %s", err.Error())
}
trans.Close()
}

View file

@ -1,540 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"testing"
)
var binaryProtoF = NewTBinaryProtocolFactoryDefault()
var compactProtoF = NewTCompactProtocolFactory()
var buf = bytes.NewBuffer(make([]byte, 0, 1024))
var tfv = []TTransportFactory{
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
}
func BenchmarkBinaryBool_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkBinaryByte_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkBinaryI16_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkBinaryI32_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkBinaryI64_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkBinaryDouble_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkBinaryString_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkBinaryBinary_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkBinaryBool_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkBinaryByte_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkBinaryI16_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkBinaryI32_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkBinaryI64_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkBinaryDouble_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkBinaryString_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkBinaryBinary_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkBinaryBool_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkBinaryByte_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkBinaryI16_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkBinaryI32_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkBinaryI64_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkBinaryDouble_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkBinaryString_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkBinaryBinary_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := binaryProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkCompactBool_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkCompactByte_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkCompactI16_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkCompactI32_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkCompactI64_0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkCompactDouble0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkCompactString0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkCompactBinary0(b *testing.B) {
trans, err := tfv[0].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkCompactBool_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkCompactByte_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkCompactI16_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkCompactI32_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkCompactI64_1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkCompactDouble1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkCompactString1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkCompactBinary1(b *testing.B) {
trans, err := tfv[1].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}
func BenchmarkCompactBool_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBool(b, p, trans)
}
}
func BenchmarkCompactByte_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteByte(b, p, trans)
}
}
func BenchmarkCompactI16_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI16(b, p, trans)
}
}
func BenchmarkCompactI32_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI32(b, p, trans)
}
}
func BenchmarkCompactI64_2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteI64(b, p, trans)
}
}
func BenchmarkCompactDouble2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteDouble(b, p, trans)
}
}
func BenchmarkCompactString2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteString(b, p, trans)
}
}
func BenchmarkCompactBinary2(b *testing.B) {
trans, err := tfv[2].GetTransport(nil)
if err != nil {
b.Fatal(err)
}
p := compactProtoF.GetProtocol(trans)
for i := 0; i < b.N; i++ {
ReadWriteBinary(b, p, trans)
}
}

View file

@ -1,80 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
)
// Memory buffer-based implementation of the TTransport interface.
type TMemoryBuffer struct {
*bytes.Buffer
size int
}
type TMemoryBufferTransportFactory struct {
size int
}
func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) (TTransport, error) {
if trans != nil {
t, ok := trans.(*TMemoryBuffer)
if ok && t.size > 0 {
return NewTMemoryBufferLen(t.size), nil
}
}
return NewTMemoryBufferLen(p.size), nil
}
func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory {
return &TMemoryBufferTransportFactory{size: size}
}
func NewTMemoryBuffer() *TMemoryBuffer {
return &TMemoryBuffer{Buffer: &bytes.Buffer{}, size: 0}
}
func NewTMemoryBufferLen(size int) *TMemoryBuffer {
buf := make([]byte, 0, size)
return &TMemoryBuffer{Buffer: bytes.NewBuffer(buf), size: size}
}
func (p *TMemoryBuffer) IsOpen() bool {
return true
}
func (p *TMemoryBuffer) Open() error {
return nil
}
func (p *TMemoryBuffer) Close() error {
p.Buffer.Reset()
return nil
}
// Flushing a memory buffer is a no-op
func (p *TMemoryBuffer) Flush(ctx context.Context) error {
return nil
}
func (p *TMemoryBuffer) RemainingBytes() (num_bytes uint64) {
return uint64(p.Buffer.Len())
}

View file

@ -1,29 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"testing"
)
func TestMemoryBuffer(t *testing.T) {
trans := NewTMemoryBufferLen(1024)
TransportTest(t, trans, trans)
}

View file

@ -1,31 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Message type constants in the Thrift protocol.
type TMessageType int32
const (
INVALID_TMESSAGE_TYPE TMessageType = 0
CALL TMessageType = 1
REPLY TMessageType = 2
EXCEPTION TMessageType = 3
ONEWAY TMessageType = 4
)

View file

@ -1,170 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"fmt"
"strings"
)
/*
TMultiplexedProtocol is a protocol-independent concrete decorator
that allows a Thrift client to communicate with a multiplexing Thrift server,
by prepending the service name to the function name during function calls.
NOTE: THIS IS NOT USED BY SERVERS. On the server, use TMultiplexedProcessor to handle request
from a multiplexing client.
This example uses a single socket transport to invoke two services:
socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
transport := thrift.NewTFramedTransport(socket)
protocol := thrift.NewTBinaryProtocolTransport(transport)
mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator")
service := Calculator.NewCalculatorClient(mp)
mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport")
service2 := WeatherReport.NewWeatherReportClient(mp2)
err := transport.Open()
if err != nil {
t.Fatal("Unable to open client socket", err)
}
fmt.Println(service.Add(2,2))
fmt.Println(service2.GetTemperature())
*/
type TMultiplexedProtocol struct {
TProtocol
serviceName string
}
const MULTIPLEXED_SEPARATOR = ":"
func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol {
return &TMultiplexedProtocol{
TProtocol: protocol,
serviceName: serviceName,
}
}
func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
if typeId == CALL || typeId == ONEWAY {
return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
} else {
return t.TProtocol.WriteMessageBegin(name, typeId, seqid)
}
}
/*
TMultiplexedProcessor is a TProcessor allowing
a single TServer to provide multiple services.
To do so, you instantiate the processor and then register additional
processors with it, as shown in the following example:
var processor = thrift.NewTMultiplexedProcessor()
firstProcessor :=
processor.RegisterProcessor("FirstService", firstProcessor)
processor.registerProcessor(
"Calculator",
Calculator.NewCalculatorProcessor(&CalculatorHandler{}),
)
processor.registerProcessor(
"WeatherReport",
WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}),
)
serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT)
if err != nil {
t.Fatal("Unable to create server socket", err)
}
server := thrift.NewTSimpleServer2(processor, serverTransport)
server.Serve();
*/
type TMultiplexedProcessor struct {
serviceProcessorMap map[string]TProcessor
DefaultProcessor TProcessor
}
func NewTMultiplexedProcessor() *TMultiplexedProcessor {
return &TMultiplexedProcessor{
serviceProcessorMap: make(map[string]TProcessor),
}
}
func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) {
t.DefaultProcessor = processor
}
func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) {
if t.serviceProcessorMap == nil {
t.serviceProcessorMap = make(map[string]TProcessor)
}
t.serviceProcessorMap[name] = processor
}
func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
name, typeId, seqid, err := in.ReadMessageBegin()
if err != nil {
return false, err
}
if typeId != CALL && typeId != ONEWAY {
return false, fmt.Errorf("Unexpected message type %v", typeId)
}
//extract the service name
v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
if len(v) != 2 {
if t.DefaultProcessor != nil {
smb := NewStoredMessageProtocol(in, name, typeId, seqid)
return t.DefaultProcessor.Process(ctx, smb, out)
}
return false, fmt.Errorf("Service name not found in message name: %s. Did you forget to use a TMultiplexProtocol in your client?", name)
}
actualProcessor, ok := t.serviceProcessorMap[v[0]]
if !ok {
return false, fmt.Errorf("Service name not found: %s. Did you forget to call registerProcessor()?", v[0])
}
smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
return actualProcessor.Process(ctx, smb, out)
}
//Protocol that use stored message for ReadMessageBegin
type storedMessageProtocol struct {
TProtocol
name string
typeId TMessageType
seqid int32
}
func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol {
return &storedMessageProtocol{protocol, name, typeId, seqid}
}
func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
return s.name, s.typeId, s.seqid, nil
}

View file

@ -1,164 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"math"
"strconv"
)
type Numeric interface {
Int64() int64
Int32() int32
Int16() int16
Byte() byte
Int() int
Float64() float64
Float32() float32
String() string
isNull() bool
}
type numeric struct {
iValue int64
dValue float64
sValue string
isNil bool
}
var (
INFINITY Numeric
NEGATIVE_INFINITY Numeric
NAN Numeric
ZERO Numeric
NUMERIC_NULL Numeric
)
func NewNumericFromDouble(dValue float64) Numeric {
if math.IsInf(dValue, 1) {
return INFINITY
}
if math.IsInf(dValue, -1) {
return NEGATIVE_INFINITY
}
if math.IsNaN(dValue) {
return NAN
}
iValue := int64(dValue)
sValue := strconv.FormatFloat(dValue, 'g', 10, 64)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI64(iValue int64) Numeric {
dValue := float64(iValue)
sValue := string(iValue)
isNil := false
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromI32(iValue int32) Numeric {
dValue := float64(iValue)
sValue := string(iValue)
isNil := false
return &numeric{iValue: int64(iValue), dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromString(sValue string) Numeric {
if sValue == INFINITY.String() {
return INFINITY
}
if sValue == NEGATIVE_INFINITY.String() {
return NEGATIVE_INFINITY
}
if sValue == NAN.String() {
return NAN
}
iValue, _ := strconv.ParseInt(sValue, 10, 64)
dValue, _ := strconv.ParseFloat(sValue, 64)
isNil := len(sValue) == 0
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNil}
}
func NewNumericFromJSONString(sValue string, isNull bool) Numeric {
if isNull {
return NewNullNumeric()
}
if sValue == JSON_INFINITY {
return INFINITY
}
if sValue == JSON_NEGATIVE_INFINITY {
return NEGATIVE_INFINITY
}
if sValue == JSON_NAN {
return NAN
}
iValue, _ := strconv.ParseInt(sValue, 10, 64)
dValue, _ := strconv.ParseFloat(sValue, 64)
return &numeric{iValue: iValue, dValue: dValue, sValue: sValue, isNil: isNull}
}
func NewNullNumeric() Numeric {
return &numeric{iValue: 0, dValue: 0.0, sValue: "", isNil: true}
}
func (p *numeric) Int64() int64 {
return p.iValue
}
func (p *numeric) Int32() int32 {
return int32(p.iValue)
}
func (p *numeric) Int16() int16 {
return int16(p.iValue)
}
func (p *numeric) Byte() byte {
return byte(p.iValue)
}
func (p *numeric) Int() int {
return int(p.iValue)
}
func (p *numeric) Float64() float64 {
return p.dValue
}
func (p *numeric) Float32() float32 {
return float32(p.dValue)
}
func (p *numeric) String() string {
return p.sValue
}
func (p *numeric) isNull() bool {
return p.isNil
}
func init() {
INFINITY = &numeric{iValue: 0, dValue: math.Inf(1), sValue: "Infinity", isNil: false}
NEGATIVE_INFINITY = &numeric{iValue: 0, dValue: math.Inf(-1), sValue: "-Infinity", isNil: false}
NAN = &numeric{iValue: 0, dValue: math.NaN(), sValue: "NaN", isNil: false}
ZERO = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: false}
NUMERIC_NULL = &numeric{iValue: 0, dValue: 0, sValue: "0", isNil: true}
}

View file

@ -1,50 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
///////////////////////////////////////////////////////////////////////////////
// This file is home to helpers that convert from various base types to
// respective pointer types. This is necessary because Go does not permit
// references to constants, nor can a pointer type to base type be allocated
// and initialized in a single expression.
//
// E.g., this is not allowed:
//
// var ip *int = &5
//
// But this *is* allowed:
//
// func IntPtr(i int) *int { return &i }
// var ip *int = IntPtr(5)
//
// Since pointers to base types are commonplace as [optional] fields in
// exported thrift structs, we factor such helpers here.
///////////////////////////////////////////////////////////////////////////////
func Float32Ptr(v float32) *float32 { return &v }
func Float64Ptr(v float64) *float64 { return &v }
func IntPtr(v int) *int { return &v }
func Int32Ptr(v int32) *int32 { return &v }
func Int64Ptr(v int64) *int64 { return &v }
func StringPtr(v string) *string { return &v }
func Uint32Ptr(v uint32) *uint32 { return &v }
func Uint64Ptr(v uint64) *uint64 { return &v }
func BoolPtr(v bool) *bool { return &v }
func ByteSlicePtr(v []byte) *[]byte { return &v }

View file

@ -1,70 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "context"
// A processor is a generic object which operates upon an input stream and
// writes to some output stream.
type TProcessor interface {
Process(ctx context.Context, in, out TProtocol) (bool, TException)
}
type TProcessorFunction interface {
Process(ctx context.Context, seqId int32, in, out TProtocol) (bool, TException)
}
// The default processor factory just returns a singleton
// instance.
type TProcessorFactory interface {
GetProcessor(trans TTransport) TProcessor
}
type tProcessorFactory struct {
processor TProcessor
}
func NewTProcessorFactory(p TProcessor) TProcessorFactory {
return &tProcessorFactory{processor: p}
}
func (p *tProcessorFactory) GetProcessor(trans TTransport) TProcessor {
return p.processor
}
/**
* The default processor factory just returns a singleton
* instance.
*/
type TProcessorFunctionFactory interface {
GetProcessorFunction(trans TTransport) TProcessorFunction
}
type tProcessorFunctionFactory struct {
processor TProcessorFunction
}
func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactory {
return &tProcessorFunctionFactory{processor: p}
}
func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction {
return p.processor
}

View file

@ -1,179 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"errors"
"fmt"
)
const (
VERSION_MASK = 0xffff0000
VERSION_1 = 0x80010000
)
type TProtocol interface {
WriteMessageBegin(name string, typeId TMessageType, seqid int32) error
WriteMessageEnd() error
WriteStructBegin(name string) error
WriteStructEnd() error
WriteFieldBegin(name string, typeId TType, id int16) error
WriteFieldEnd() error
WriteFieldStop() error
WriteMapBegin(keyType TType, valueType TType, size int) error
WriteMapEnd() error
WriteListBegin(elemType TType, size int) error
WriteListEnd() error
WriteSetBegin(elemType TType, size int) error
WriteSetEnd() error
WriteBool(value bool) error
WriteByte(value int8) error
WriteI16(value int16) error
WriteI32(value int32) error
WriteI64(value int64) error
WriteDouble(value float64) error
WriteString(value string) error
WriteBinary(value []byte) error
ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error)
ReadMessageEnd() error
ReadStructBegin() (name string, err error)
ReadStructEnd() error
ReadFieldBegin() (name string, typeId TType, id int16, err error)
ReadFieldEnd() error
ReadMapBegin() (keyType TType, valueType TType, size int, err error)
ReadMapEnd() error
ReadListBegin() (elemType TType, size int, err error)
ReadListEnd() error
ReadSetBegin() (elemType TType, size int, err error)
ReadSetEnd() error
ReadBool() (value bool, err error)
ReadByte() (value int8, err error)
ReadI16() (value int16, err error)
ReadI32() (value int32, err error)
ReadI64() (value int64, err error)
ReadDouble() (value float64, err error)
ReadString() (value string, err error)
ReadBinary() (value []byte, err error)
Skip(fieldType TType) (err error)
Flush(ctx context.Context) (err error)
Transport() TTransport
}
// The maximum recursive depth the skip() function will traverse
const DEFAULT_RECURSION_DEPTH = 64
// Skips over the next data element from the provided input TProtocol object.
func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) {
return Skip(prot, typeId, DEFAULT_RECURSION_DEPTH)
}
// Skips over the next data element from the provided input TProtocol object.
func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) {
if maxDepth <= 0 {
return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded"))
}
switch fieldType {
case STOP:
return
case BOOL:
_, err = self.ReadBool()
return
case BYTE:
_, err = self.ReadByte()
return
case I16:
_, err = self.ReadI16()
return
case I32:
_, err = self.ReadI32()
return
case I64:
_, err = self.ReadI64()
return
case DOUBLE:
_, err = self.ReadDouble()
return
case STRING:
_, err = self.ReadString()
return
case STRUCT:
if _, err = self.ReadStructBegin(); err != nil {
return err
}
for {
_, typeId, _, _ := self.ReadFieldBegin()
if typeId == STOP {
break
}
err := Skip(self, typeId, maxDepth-1)
if err != nil {
return err
}
self.ReadFieldEnd()
}
return self.ReadStructEnd()
case MAP:
keyType, valueType, size, err := self.ReadMapBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, keyType, maxDepth-1)
if err != nil {
return err
}
self.Skip(valueType)
}
return self.ReadMapEnd()
case SET:
elemType, size, err := self.ReadSetBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadSetEnd()
case LIST:
elemType, size, err := self.ReadListBegin()
if err != nil {
return err
}
for i := 0; i < size; i++ {
err := Skip(self, elemType, maxDepth-1)
if err != nil {
return err
}
}
return self.ReadListEnd()
default:
return NewTProtocolExceptionWithType(INVALID_DATA, errors.New(fmt.Sprintf("Unknown data type %d", fieldType)))
}
return nil
}

View file

@ -1,77 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"encoding/base64"
)
// Thrift Protocol exception
type TProtocolException interface {
TException
TypeId() int
}
const (
UNKNOWN_PROTOCOL_EXCEPTION = 0
INVALID_DATA = 1
NEGATIVE_SIZE = 2
SIZE_LIMIT = 3
BAD_VERSION = 4
NOT_IMPLEMENTED = 5
DEPTH_LIMIT = 6
)
type tProtocolException struct {
typeId int
message string
}
func (p *tProtocolException) TypeId() int {
return p.typeId
}
func (p *tProtocolException) String() string {
return p.message
}
func (p *tProtocolException) Error() string {
return p.message
}
func NewTProtocolException(err error) TProtocolException {
if err == nil {
return nil
}
if e, ok := err.(TProtocolException); ok {
return e
}
if _, ok := err.(base64.CorruptInputError); ok {
return &tProtocolException{INVALID_DATA, err.Error()}
}
return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()}
}
func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
if err == nil {
return nil
}
return &tProtocolException{errType, err.Error()}
}

View file

@ -1,25 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Factory interface for constructing protocol instances.
type TProtocolFactory interface {
GetProtocol(trans TTransport) TProtocol
}

View file

@ -1,517 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"context"
"io/ioutil"
"math"
"net"
"net/http"
"testing"
)
const PROTOCOL_BINARY_DATA_SIZE = 155
var (
protocol_bdata []byte // test data for writing; same as data
BOOL_VALUES []bool
BYTE_VALUES []int8
INT16_VALUES []int16
INT32_VALUES []int32
INT64_VALUES []int64
DOUBLE_VALUES []float64
STRING_VALUES []string
)
func init() {
protocol_bdata = make([]byte, PROTOCOL_BINARY_DATA_SIZE)
for i := 0; i < PROTOCOL_BINARY_DATA_SIZE; i++ {
protocol_bdata[i] = byte((i + 'a') % 255)
}
BOOL_VALUES = []bool{false, true, false, false, true}
BYTE_VALUES = []int8{117, 0, 1, 32, 127, -128, -1}
INT16_VALUES = []int16{459, 0, 1, -1, -128, 127, 32767, -32768}
INT32_VALUES = []int32{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535}
INT64_VALUES = []int64{459, 0, 1, -1, -128, 127, 32767, 2147483647, -2147483535, 34359738481, -35184372088719, -9223372036854775808, 9223372036854775807}
DOUBLE_VALUES = []float64{459.3, 0.0, -1.0, 1.0, 0.5, 0.3333, 3.14159, 1.537e-38, 1.673e25, 6.02214179e23, -6.02214179e23, INFINITY.Float64(), NEGATIVE_INFINITY.Float64(), NAN.Float64()}
STRING_VALUES = []string{"", "a", "st[uf]f", "st,u:ff with spaces", "stuff\twith\nescape\\characters'...\"lots{of}fun</xml>"}
}
type HTTPEchoServer struct{}
type HTTPHeaderEchoServer struct{}
func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
} else {
w.WriteHeader(http.StatusOK)
w.Write(buf)
}
}
func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write(buf)
} else {
w.WriteHeader(http.StatusOK)
w.Write(buf)
}
}
func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) {
addr, err := FindAvailableTCPServerPort(40000)
if err != nil {
t.Fatalf("Unable to find available tcp port addr: %s", err)
return nil, addr
}
l, err := net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
return l, addr
}
go http.Serve(l, &HTTPEchoServer{})
return l, addr
}
func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) {
addr, err := FindAvailableTCPServerPort(40000)
if err != nil {
t.Fatalf("Unable to find available tcp port addr: %s", err)
return nil, addr
}
l, err := net.Listen(addr.Network(), addr.String())
if err != nil {
t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
return l, addr
}
go http.Serve(l, &HTTPHeaderEchoServer{})
return l, addr
}
func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
l, addr := HttpClientSetupForTest(t)
defer l.Close()
transports := []TTransportFactory{
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(0, NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(6, NewTMemoryBufferTransportFactory(1024)),
NewTZlibTransportFactoryWithFactory(9, NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024))),
NewTHttpPostClientTransportFactory("http://" + addr.String()),
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteBool(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteByte(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI16(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI32(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI64(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteDouble(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteString(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteBinary(t, p, trans)
trans.Close()
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
p := protocolFactory.GetProtocol(trans)
ReadWriteI64(t, p, trans)
ReadWriteDouble(t, p, trans)
ReadWriteBinary(t, p, trans)
ReadWriteByte(t, p, trans)
trans.Close()
}
}
func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(BOOL)
thelen := len(BOOL_VALUES)
err := p.WriteListBegin(thetype, thelen)
if err != nil {
t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteBool", p, trans, err, thetype)
}
for k, v := range BOOL_VALUES {
err = p.WriteBool(v)
if err != nil {
t.Errorf("%s: %T %T %v Error writing bool in list at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
}
p.WriteListEnd()
if err != nil {
t.Errorf("%s: %T %T %v Error writing list end: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteBool", p, trans, err, BOOL_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteBool", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteBool", p, trans, thelen, thelen2)
}
}
for k, v := range BOOL_VALUES {
value, err := p.ReadBool()
if err != nil {
t.Errorf("%s: %T %T %v Error reading bool at index %v: %v", "ReadWriteBool", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: index %v %v %v %v != %v", "ReadWriteBool", k, p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteBool", p, trans, err)
}
}
func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(BYTE)
thelen := len(BYTE_VALUES)
err := p.WriteListBegin(thetype, thelen)
if err != nil {
t.Errorf("%s: %T %T %q Error writing list begin: %q", "ReadWriteByte", p, trans, err, thetype)
}
for k, v := range BYTE_VALUES {
err = p.WriteByte(v)
if err != nil {
t.Errorf("%s: %T %T %q Error writing byte in list at index %d: %q", "ReadWriteByte", p, trans, err, k, v)
}
}
err = p.WriteListEnd()
if err != nil {
t.Errorf("%s: %T %T %q Error writing list end: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
err = p.Flush(context.Background())
if err != nil {
t.Errorf("%s: %T %T %q Error flushing list of bytes: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteByte", p, trans, err, BYTE_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteByte", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteByte", p, trans, thelen, thelen2)
}
}
for k, v := range BYTE_VALUES {
value, err := p.ReadByte()
if err != nil {
t.Errorf("%s: %T %T %q Error reading byte at index %d: %q", "ReadWriteByte", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteByte", p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteByte", p, trans, err)
}
}
func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I16)
thelen := len(INT16_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range INT16_VALUES {
p.WriteI16(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI16", p, trans, err, INT16_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI16", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI16", p, trans, thelen, thelen2)
}
}
for k, v := range INT16_VALUES {
value, err := p.ReadI16()
if err != nil {
t.Errorf("%s: %T %T %q Error reading int16 at index %d: %q", "ReadWriteI16", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteI16", p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI16", p, trans, err)
}
}
func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I32)
thelen := len(INT32_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range INT32_VALUES {
p.WriteI32(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI32", p, trans, err, INT32_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI32", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI32", p, trans, thelen, thelen2)
}
}
for k, v := range INT32_VALUES {
value, err := p.ReadI32()
if err != nil {
t.Errorf("%s: %T %T %q Error reading int32 at index %d: %q", "ReadWriteI32", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %d != %d", "ReadWriteI32", p, trans, v, value)
}
}
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI32", p, trans, err)
}
}
func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(I64)
thelen := len(INT64_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range INT64_VALUES {
p.WriteI64(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteI64", p, trans, err, INT64_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteI64", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteI64", p, trans, thelen, thelen2)
}
}
for k, v := range INT64_VALUES {
value, err := p.ReadI64()
if err != nil {
t.Errorf("%s: %T %T %q Error reading int64 at index %d: %q", "ReadWriteI64", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %q != %q", "ReadWriteI64", p, trans, v, value)
}
}
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteI64", p, trans, err)
}
}
func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(DOUBLE)
thelen := len(DOUBLE_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range DOUBLE_VALUES {
p.WriteDouble(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %v Error reading list: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES)
}
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteDouble", p, trans, thelen, thelen2)
}
for k, v := range DOUBLE_VALUES {
value, err := p.ReadDouble()
if err != nil {
t.Errorf("%s: %T %T %q Error reading double at index %d: %v", "ReadWriteDouble", p, trans, err, k, v)
}
if math.IsNaN(v) {
if !math.IsNaN(value) {
t.Errorf("%s: %T %T math.IsNaN(%v) != math.IsNaN(%v)", "ReadWriteDouble", p, trans, v, value)
}
} else if v != value {
t.Errorf("%s: %T %T %v != %v", "ReadWriteDouble", p, trans, v, value)
}
}
err = p.ReadListEnd()
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteDouble", p, trans, err)
}
}
func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) {
thetype := TType(STRING)
thelen := len(STRING_VALUES)
p.WriteListBegin(thetype, thelen)
for _, v := range STRING_VALUES {
p.WriteString(v)
}
p.WriteListEnd()
p.Flush(context.Background())
thetype2, thelen2, err := p.ReadListBegin()
if err != nil {
t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteString", p, trans, err, STRING_VALUES)
}
_, ok := p.(*TSimpleJSONProtocol)
if !ok {
if thetype != thetype2 {
t.Errorf("%s: %T %T type %s != type %s", "ReadWriteString", p, trans, thetype, thetype2)
}
if thelen != thelen2 {
t.Errorf("%s: %T %T len %v != len %v", "ReadWriteString", p, trans, thelen, thelen2)
}
}
for k, v := range STRING_VALUES {
value, err := p.ReadString()
if err != nil {
t.Errorf("%s: %T %T %q Error reading string at index %d: %q", "ReadWriteString", p, trans, err, k, v)
}
if v != value {
t.Errorf("%s: %T %T %v != %v", "ReadWriteString", p, trans, v, value)
}
}
if err != nil {
t.Errorf("%s: %T %T Unable to read list end: %q", "ReadWriteString", p, trans, err)
}
}
func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) {
v := protocol_bdata
p.WriteBinary(v)
p.Flush(context.Background())
value, err := p.ReadBinary()
if err != nil {
t.Errorf("%s: %T %T Unable to read binary: %s", "ReadWriteBinary", p, trans, err.Error())
}
if len(v) != len(value) {
t.Errorf("%s: %T %T len(v) != len(value)... %d != %d", "ReadWriteBinary", p, trans, len(v), len(value))
} else {
for i := 0; i < len(v); i++ {
if v[i] != value[i] {
t.Errorf("%s: %T %T %s != %s", "ReadWriteBinary", p, trans, v, value)
}
}
}
}

View file

@ -1,68 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import "io"
type RichTransport struct {
TTransport
}
// Wraps Transport to provide TRichTransport interface
func NewTRichTransport(trans TTransport) *RichTransport {
return &RichTransport{trans}
}
func (r *RichTransport) ReadByte() (c byte, err error) {
return readByte(r.TTransport)
}
func (r *RichTransport) WriteByte(c byte) error {
return writeByte(r.TTransport, c)
}
func (r *RichTransport) WriteString(s string) (n int, err error) {
return r.Write([]byte(s))
}
func (r *RichTransport) RemainingBytes() (num_bytes uint64) {
return r.TTransport.RemainingBytes()
}
func readByte(r io.Reader) (c byte, err error) {
v := [1]byte{0}
n, err := r.Read(v[0:1])
if n > 0 && (err == nil || err == io.EOF) {
return v[0], nil
}
if n > 0 && err != nil {
return v[0], err
}
if err != nil {
return 0, err
}
return v[0], nil
}
func writeByte(w io.Writer, c byte) error {
v := [1]byte{c}
_, err := w.Write(v[0:1])
return err
}

View file

@ -1,89 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"bytes"
"errors"
"io"
"reflect"
"testing"
)
func TestEnsureTransportsAreRich(t *testing.T) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
transports := []TTransportFactory{
NewTMemoryBufferTransportFactory(1024),
NewStreamTransportFactory(buf, buf, true),
NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)),
NewTHttpPostClientTransportFactory("http://127.0.0.1"),
}
for _, tf := range transports {
trans, err := tf.GetTransport(nil)
if err != nil {
t.Error(err)
continue
}
_, ok := trans.(TRichTransport)
if !ok {
t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans))
}
}
}
// TestReadByte tests whether readByte handles error cases correctly.
func TestReadByte(t *testing.T) {
for i, test := range readByteTests {
v, err := readByte(test.r)
if v != test.v {
t.Fatalf("TestReadByte %d: value differs. Expected %d, got %d", i, test.v, test.r.v)
}
if err != test.err {
t.Fatalf("TestReadByte %d: error differs. Expected %s, got %s", i, test.err, test.r.err)
}
}
}
var someError = errors.New("Some error")
var readByteTests = []struct {
r *mockReader
v byte
err error
}{
{&mockReader{0, 55, io.EOF}, 0, io.EOF}, // reader sends EOF w/o data
{&mockReader{0, 55, someError}, 0, someError}, // reader sends some other error
{&mockReader{1, 55, nil}, 55, nil}, // reader sends data w/o error
{&mockReader{1, 55, io.EOF}, 55, nil}, // reader sends data with EOF
{&mockReader{1, 55, someError}, 55, someError}, // reader sends data withsome error
}
type mockReader struct {
n int
v byte
err error
}
func (r *mockReader) Read(p []byte) (n int, err error) {
if r.n > 0 {
p[0] = r.v
}
return r.n, r.err
}

View file

@ -1,79 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
)
type TSerializer struct {
Transport *TMemoryBuffer
Protocol TProtocol
}
type TStruct interface {
Write(p TProtocol) error
Read(p TProtocol) error
}
func NewTSerializer() *TSerializer {
transport := NewTMemoryBufferLen(1024)
protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
return &TSerializer{
transport,
protocol}
}
func (t *TSerializer) WriteString(ctx context.Context, msg TStruct) (s string, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(ctx); err != nil {
return
}
return t.Transport.String(), nil
}
func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err error) {
t.Transport.Reset()
if err = msg.Write(t.Protocol); err != nil {
return
}
if err = t.Protocol.Flush(ctx); err != nil {
return
}
if err = t.Transport.Flush(ctx); err != nil {
return
}
b = append(b, t.Transport.Bytes()...)
return
}

View file

@ -1,170 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"context"
"errors"
"fmt"
"testing"
)
type ProtocolFactory interface {
GetProtocol(t TTransport) TProtocol
}
func compareStructs(m, m1 MyTestStruct) (bool, error) {
switch {
case m.On != m1.On:
return false, errors.New("Boolean not equal")
case m.B != m1.B:
return false, errors.New("Byte not equal")
case m.Int16 != m1.Int16:
return false, errors.New("Int16 not equal")
case m.Int32 != m1.Int32:
return false, errors.New("Int32 not equal")
case m.Int64 != m1.Int64:
return false, errors.New("Int64 not equal")
case m.D != m1.D:
return false, errors.New("Double not equal")
case m.St != m1.St:
return false, errors.New("String not equal")
case len(m.Bin) != len(m1.Bin):
return false, errors.New("Binary size not equal")
case len(m.Bin) == len(m1.Bin):
for i := range m.Bin {
if m.Bin[i] != m1.Bin[i] {
return false, errors.New("Binary not equal")
}
}
case len(m.StringMap) != len(m1.StringMap):
return false, errors.New("StringMap size not equal")
case len(m.StringList) != len(m1.StringList):
return false, errors.New("StringList size not equal")
case len(m.StringSet) != len(m1.StringSet):
return false, errors.New("StringSet size not equal")
case m.E != m1.E:
return false, errors.New("MyTestEnum not equal")
default:
return true, nil
}
return true, nil
}
func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) {
t := NewTSerializer()
t.Protocol = pf.GetProtocol(t.Transport)
var m = MyTestStruct{}
m.On = true
m.B = int8(0)
m.Int16 = 1
m.Int32 = 2
m.Int64 = 3
m.D = 4.1
m.St = "Test"
m.Bin = make([]byte, 10)
m.StringMap = make(map[string]string, 5)
m.StringList = make([]string, 5)
m.StringSet = make(map[string]struct{}, 5)
m.E = 2
s, err := t.WriteString(context.Background(), &m)
if err != nil {
return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
}
t1 := NewTDeserializer()
t1.Protocol = pf.GetProtocol(t1.Transport)
var m1 = MyTestStruct{}
if err = t1.ReadString(&m1, s); err != nil {
return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
}
return compareStructs(m, m1)
}
func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) {
t := NewTSerializer()
t.Protocol = pf.GetProtocol(t.Transport)
var m = MyTestStruct{}
m.On = false
m.B = int8(0)
m.Int16 = 1
m.Int32 = 2
m.Int64 = 3
m.D = 4.1
m.St = "Test"
m.Bin = make([]byte, 10)
m.StringMap = make(map[string]string, 5)
m.StringList = make([]string, 5)
m.StringSet = make(map[string]struct{}, 5)
m.E = 2
s, err := t.WriteString(context.Background(), &m)
if err != nil {
return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
}
t1 := NewTDeserializer()
t1.Protocol = pf.GetProtocol(t1.Transport)
var m1 = MyTestStruct{}
if err = t1.ReadString(&m1, s); err != nil {
return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
}
return compareStructs(m, m1)
}
func TestSerializer(t *testing.T) {
var protocol_factories map[string]ProtocolFactory
protocol_factories = make(map[string]ProtocolFactory)
protocol_factories["Binary"] = NewTBinaryProtocolFactoryDefault()
protocol_factories["Compact"] = NewTCompactProtocolFactory()
//protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design
protocol_factories["JSON"] = NewTJSONProtocolFactory()
var tests map[string]func(*testing.T, ProtocolFactory) (bool, error)
tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error))
tests["Test 1"] = ProtocolTest1
tests["Test 2"] = ProtocolTest2
//tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests
for name, pf := range protocol_factories {
for test, f := range tests {
if s, err := f(t, pf); !s || err != nil {
t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err)
}
}
}
}

View file

@ -1,633 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
// Autogenerated by Thrift Compiler (0.12.0)
// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
/* THE FOLLOWING THRIFT FILE WAS USED TO CREATE THIS
enum MyTestEnum {
FIRST = 1,
SECOND = 2,
THIRD = 3,
FOURTH = 4,
}
struct MyTestStruct {
1: bool on,
2: byte b,
3: i16 int16,
4: i32 int32,
5: i64 int64,
6: double d,
7: string st,
8: binary bin,
9: map<string, string> stringMap,
10: list<string> stringList,
11: set<string> stringSet,
12: MyTestEnum e,
}
*/
import (
"fmt"
)
// (needed to ensure safety because of naive import list construction.)
var _ = ZERO
var _ = fmt.Printf
var GoUnusedProtection__ int
type MyTestEnum int64
const (
MyTestEnum_FIRST MyTestEnum = 1
MyTestEnum_SECOND MyTestEnum = 2
MyTestEnum_THIRD MyTestEnum = 3
MyTestEnum_FOURTH MyTestEnum = 4
)
func (p MyTestEnum) String() string {
switch p {
case MyTestEnum_FIRST:
return "FIRST"
case MyTestEnum_SECOND:
return "SECOND"
case MyTestEnum_THIRD:
return "THIRD"
case MyTestEnum_FOURTH:
return "FOURTH"
}
return "<UNSET>"
}
func MyTestEnumFromString(s string) (MyTestEnum, error) {
switch s {
case "FIRST":
return MyTestEnum_FIRST, nil
case "SECOND":
return MyTestEnum_SECOND, nil
case "THIRD":
return MyTestEnum_THIRD, nil
case "FOURTH":
return MyTestEnum_FOURTH, nil
}
return MyTestEnum(0), fmt.Errorf("not a valid MyTestEnum string")
}
func MyTestEnumPtr(v MyTestEnum) *MyTestEnum { return &v }
type MyTestStruct struct {
On bool `thrift:"on,1" json:"on"`
B int8 `thrift:"b,2" json:"b"`
Int16 int16 `thrift:"int16,3" json:"int16"`
Int32 int32 `thrift:"int32,4" json:"int32"`
Int64 int64 `thrift:"int64,5" json:"int64"`
D float64 `thrift:"d,6" json:"d"`
St string `thrift:"st,7" json:"st"`
Bin []byte `thrift:"bin,8" json:"bin"`
StringMap map[string]string `thrift:"stringMap,9" json:"stringMap"`
StringList []string `thrift:"stringList,10" json:"stringList"`
StringSet map[string]struct{} `thrift:"stringSet,11" json:"stringSet"`
E MyTestEnum `thrift:"e,12" json:"e"`
}
func NewMyTestStruct() *MyTestStruct {
return &MyTestStruct{}
}
func (p *MyTestStruct) GetOn() bool {
return p.On
}
func (p *MyTestStruct) GetB() int8 {
return p.B
}
func (p *MyTestStruct) GetInt16() int16 {
return p.Int16
}
func (p *MyTestStruct) GetInt32() int32 {
return p.Int32
}
func (p *MyTestStruct) GetInt64() int64 {
return p.Int64
}
func (p *MyTestStruct) GetD() float64 {
return p.D
}
func (p *MyTestStruct) GetSt() string {
return p.St
}
func (p *MyTestStruct) GetBin() []byte {
return p.Bin
}
func (p *MyTestStruct) GetStringMap() map[string]string {
return p.StringMap
}
func (p *MyTestStruct) GetStringList() []string {
return p.StringList
}
func (p *MyTestStruct) GetStringSet() map[string]struct{} {
return p.StringSet
}
func (p *MyTestStruct) GetE() MyTestEnum {
return p.E
}
func (p *MyTestStruct) Read(iprot TProtocol) error {
if _, err := iprot.ReadStructBegin(); err != nil {
return PrependError(fmt.Sprintf("%T read error: ", p), err)
}
for {
_, fieldTypeId, fieldId, err := iprot.ReadFieldBegin()
if err != nil {
return PrependError(fmt.Sprintf("%T field %d read error: ", p, fieldId), err)
}
if fieldTypeId == STOP {
break
}
switch fieldId {
case 1:
if err := p.readField1(iprot); err != nil {
return err
}
case 2:
if err := p.readField2(iprot); err != nil {
return err
}
case 3:
if err := p.readField3(iprot); err != nil {
return err
}
case 4:
if err := p.readField4(iprot); err != nil {
return err
}
case 5:
if err := p.readField5(iprot); err != nil {
return err
}
case 6:
if err := p.readField6(iprot); err != nil {
return err
}
case 7:
if err := p.readField7(iprot); err != nil {
return err
}
case 8:
if err := p.readField8(iprot); err != nil {
return err
}
case 9:
if err := p.readField9(iprot); err != nil {
return err
}
case 10:
if err := p.readField10(iprot); err != nil {
return err
}
case 11:
if err := p.readField11(iprot); err != nil {
return err
}
case 12:
if err := p.readField12(iprot); err != nil {
return err
}
default:
if err := iprot.Skip(fieldTypeId); err != nil {
return err
}
}
if err := iprot.ReadFieldEnd(); err != nil {
return err
}
}
if err := iprot.ReadStructEnd(); err != nil {
return PrependError(fmt.Sprintf("%T read struct end error: ", p), err)
}
return nil
}
func (p *MyTestStruct) readField1(iprot TProtocol) error {
if v, err := iprot.ReadBool(); err != nil {
return PrependError("error reading field 1: ", err)
} else {
p.On = v
}
return nil
}
func (p *MyTestStruct) readField2(iprot TProtocol) error {
if v, err := iprot.ReadByte(); err != nil {
return PrependError("error reading field 2: ", err)
} else {
temp := int8(v)
p.B = temp
}
return nil
}
func (p *MyTestStruct) readField3(iprot TProtocol) error {
if v, err := iprot.ReadI16(); err != nil {
return PrependError("error reading field 3: ", err)
} else {
p.Int16 = v
}
return nil
}
func (p *MyTestStruct) readField4(iprot TProtocol) error {
if v, err := iprot.ReadI32(); err != nil {
return PrependError("error reading field 4: ", err)
} else {
p.Int32 = v
}
return nil
}
func (p *MyTestStruct) readField5(iprot TProtocol) error {
if v, err := iprot.ReadI64(); err != nil {
return PrependError("error reading field 5: ", err)
} else {
p.Int64 = v
}
return nil
}
func (p *MyTestStruct) readField6(iprot TProtocol) error {
if v, err := iprot.ReadDouble(); err != nil {
return PrependError("error reading field 6: ", err)
} else {
p.D = v
}
return nil
}
func (p *MyTestStruct) readField7(iprot TProtocol) error {
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 7: ", err)
} else {
p.St = v
}
return nil
}
func (p *MyTestStruct) readField8(iprot TProtocol) error {
if v, err := iprot.ReadBinary(); err != nil {
return PrependError("error reading field 8: ", err)
} else {
p.Bin = v
}
return nil
}
func (p *MyTestStruct) readField9(iprot TProtocol) error {
_, _, size, err := iprot.ReadMapBegin()
if err != nil {
return PrependError("error reading map begin: ", err)
}
tMap := make(map[string]string, size)
p.StringMap = tMap
for i := 0; i < size; i++ {
var _key0 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_key0 = v
}
var _val1 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_val1 = v
}
p.StringMap[_key0] = _val1
}
if err := iprot.ReadMapEnd(); err != nil {
return PrependError("error reading map end: ", err)
}
return nil
}
func (p *MyTestStruct) readField10(iprot TProtocol) error {
_, size, err := iprot.ReadListBegin()
if err != nil {
return PrependError("error reading list begin: ", err)
}
tSlice := make([]string, 0, size)
p.StringList = tSlice
for i := 0; i < size; i++ {
var _elem2 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_elem2 = v
}
p.StringList = append(p.StringList, _elem2)
}
if err := iprot.ReadListEnd(); err != nil {
return PrependError("error reading list end: ", err)
}
return nil
}
func (p *MyTestStruct) readField11(iprot TProtocol) error {
_, size, err := iprot.ReadSetBegin()
if err != nil {
return PrependError("error reading set begin: ", err)
}
tSet := make(map[string]struct{}, size)
p.StringSet = tSet
for i := 0; i < size; i++ {
var _elem3 string
if v, err := iprot.ReadString(); err != nil {
return PrependError("error reading field 0: ", err)
} else {
_elem3 = v
}
p.StringSet[_elem3] = struct{}{}
}
if err := iprot.ReadSetEnd(); err != nil {
return PrependError("error reading set end: ", err)
}
return nil
}
func (p *MyTestStruct) readField12(iprot TProtocol) error {
if v, err := iprot.ReadI32(); err != nil {
return PrependError("error reading field 12: ", err)
} else {
temp := MyTestEnum(v)
p.E = temp
}
return nil
}
func (p *MyTestStruct) Write(oprot TProtocol) error {
if err := oprot.WriteStructBegin("MyTestStruct"); err != nil {
return PrependError(fmt.Sprintf("%T write struct begin error: ", p), err)
}
if err := p.writeField1(oprot); err != nil {
return err
}
if err := p.writeField2(oprot); err != nil {
return err
}
if err := p.writeField3(oprot); err != nil {
return err
}
if err := p.writeField4(oprot); err != nil {
return err
}
if err := p.writeField5(oprot); err != nil {
return err
}
if err := p.writeField6(oprot); err != nil {
return err
}
if err := p.writeField7(oprot); err != nil {
return err
}
if err := p.writeField8(oprot); err != nil {
return err
}
if err := p.writeField9(oprot); err != nil {
return err
}
if err := p.writeField10(oprot); err != nil {
return err
}
if err := p.writeField11(oprot); err != nil {
return err
}
if err := p.writeField12(oprot); err != nil {
return err
}
if err := oprot.WriteFieldStop(); err != nil {
return PrependError("write field stop error: ", err)
}
if err := oprot.WriteStructEnd(); err != nil {
return PrependError("write struct stop error: ", err)
}
return nil
}
func (p *MyTestStruct) writeField1(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("on", BOOL, 1); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 1:on: ", p), err)
}
if err := oprot.WriteBool(bool(p.On)); err != nil {
return PrependError(fmt.Sprintf("%T.on (1) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 1:on: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField2(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("b", BYTE, 2); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 2:b: ", p), err)
}
if err := oprot.WriteByte(int8(p.B)); err != nil {
return PrependError(fmt.Sprintf("%T.b (2) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 2:b: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField3(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("int16", I16, 3); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 3:int16: ", p), err)
}
if err := oprot.WriteI16(int16(p.Int16)); err != nil {
return PrependError(fmt.Sprintf("%T.int16 (3) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 3:int16: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField4(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("int32", I32, 4); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 4:int32: ", p), err)
}
if err := oprot.WriteI32(int32(p.Int32)); err != nil {
return PrependError(fmt.Sprintf("%T.int32 (4) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 4:int32: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField5(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("int64", I64, 5); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 5:int64: ", p), err)
}
if err := oprot.WriteI64(int64(p.Int64)); err != nil {
return PrependError(fmt.Sprintf("%T.int64 (5) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 5:int64: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField6(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("d", DOUBLE, 6); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 6:d: ", p), err)
}
if err := oprot.WriteDouble(float64(p.D)); err != nil {
return PrependError(fmt.Sprintf("%T.d (6) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 6:d: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField7(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("st", STRING, 7); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 7:st: ", p), err)
}
if err := oprot.WriteString(string(p.St)); err != nil {
return PrependError(fmt.Sprintf("%T.st (7) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 7:st: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField8(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("bin", STRING, 8); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 8:bin: ", p), err)
}
if err := oprot.WriteBinary(p.Bin); err != nil {
return PrependError(fmt.Sprintf("%T.bin (8) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 8:bin: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField9(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("stringMap", MAP, 9); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 9:stringMap: ", p), err)
}
if err := oprot.WriteMapBegin(STRING, STRING, len(p.StringMap)); err != nil {
return PrependError("error writing map begin: ", err)
}
for k, v := range p.StringMap {
if err := oprot.WriteString(string(k)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
if err := oprot.WriteString(string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
if err := oprot.WriteMapEnd(); err != nil {
return PrependError("error writing map end: ", err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 9:stringMap: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField10(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("stringList", LIST, 10); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 10:stringList: ", p), err)
}
if err := oprot.WriteListBegin(STRING, len(p.StringList)); err != nil {
return PrependError("error writing list begin: ", err)
}
for _, v := range p.StringList {
if err := oprot.WriteString(string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
if err := oprot.WriteListEnd(); err != nil {
return PrependError("error writing list end: ", err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 10:stringList: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("stringSet", SET, 11); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 11:stringSet: ", p), err)
}
if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil {
return PrependError("error writing set begin: ", err)
}
for v := range p.StringSet {
if err := oprot.WriteString(string(v)); err != nil {
return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err)
}
}
if err := oprot.WriteSetEnd(); err != nil {
return PrependError("error writing set end: ", err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 11:stringSet: ", p), err)
}
return err
}
func (p *MyTestStruct) writeField12(oprot TProtocol) (err error) {
if err := oprot.WriteFieldBegin("e", I32, 12); err != nil {
return PrependError(fmt.Sprintf("%T write field begin error 12:e: ", p), err)
}
if err := oprot.WriteI32(int32(p.E)); err != nil {
return PrependError(fmt.Sprintf("%T.e (12) field write error: ", p), err)
}
if err := oprot.WriteFieldEnd(); err != nil {
return PrependError(fmt.Sprintf("%T write field end error 12:e: ", p), err)
}
return err
}
func (p *MyTestStruct) String() string {
if p == nil {
return "<nil>"
}
return fmt.Sprintf("MyTestStruct(%+v)", *p)
}

View file

@ -1,35 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
type TServer interface {
ProcessorFactory() TProcessorFactory
ServerTransport() TServerTransport
InputTransportFactory() TTransportFactory
OutputTransportFactory() TTransportFactory
InputProtocolFactory() TProtocolFactory
OutputProtocolFactory() TProtocolFactory
// Starts the server
Serve() error
// Stops the server. This is optional on a per-implementation basis. Not
// all servers are required to be cleanly stoppable.
Stop() error
}

View file

@ -1,137 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"net"
"sync"
"time"
)
type TServerSocket struct {
listener net.Listener
addr net.Addr
clientTimeout time.Duration
// Protects the interrupted value to make it thread safe.
mu sync.RWMutex
interrupted bool
}
func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
return NewTServerSocketTimeout(listenAddr, 0)
}
func NewTServerSocketTimeout(listenAddr string, clientTimeout time.Duration) (*TServerSocket, error) {
addr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, err
}
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}, nil
}
// Creates a TServerSocket from a net.Addr
func NewTServerSocketFromAddrTimeout(addr net.Addr, clientTimeout time.Duration) *TServerSocket {
return &TServerSocket{addr: addr, clientTimeout: clientTimeout}
}
func (p *TServerSocket) Listen() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return nil
}
l, err := net.Listen(p.addr.Network(), p.addr.String())
if err != nil {
return err
}
p.listener = l
return nil
}
func (p *TServerSocket) Accept() (TTransport, error) {
p.mu.RLock()
interrupted := p.interrupted
p.mu.RUnlock()
if interrupted {
return nil, errTransportInterrupted
}
p.mu.Lock()
listener := p.listener
p.mu.Unlock()
if listener == nil {
return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
}
conn, err := listener.Accept()
if err != nil {
return nil, NewTTransportExceptionFromError(err)
}
return NewTSocketFromConnTimeout(conn, p.clientTimeout), nil
}
// Checks whether the socket is listening.
func (p *TServerSocket) IsListening() bool {
return p.listener != nil
}
// Connects the socket, creating a new socket object if necessary.
func (p *TServerSocket) Open() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.IsListening() {
return NewTTransportException(ALREADY_OPEN, "Server socket already open")
}
if l, err := net.Listen(p.addr.Network(), p.addr.String()); err != nil {
return err
} else {
p.listener = l
}
return nil
}
func (p *TServerSocket) Addr() net.Addr {
if p.listener != nil {
return p.listener.Addr()
}
return p.addr
}
func (p *TServerSocket) Close() error {
var err error
p.mu.Lock()
if p.IsListening() {
err = p.listener.Close()
p.listener = nil
}
p.mu.Unlock()
return err
}
func (p *TServerSocket) Interrupt() error {
p.mu.Lock()
p.interrupted = true
p.mu.Unlock()
p.Close()
return nil
}

View file

@ -1,60 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package thrift
import (
"fmt"
"testing"
)
func TestSocketIsntListeningAfterInterrupt(t *testing.T) {
host := "127.0.0.1"
port := 9090
addr := fmt.Sprintf("%s:%d", host, port)
socket := CreateServerSocket(t, addr)
socket.Listen()
socket.Interrupt()
newSocket := CreateServerSocket(t, addr)
err := newSocket.Listen()
defer newSocket.Interrupt()
if err != nil {
t.Fatalf("Failed to rebinds: %s", err)
}
}
func TestSocketConcurrency(t *testing.T) {
host := "127.0.0.1"
port := 9090
addr := fmt.Sprintf("%s:%d", host, port)
socket := CreateServerSocket(t, addr)
go func() { socket.Listen() }()
go func() { socket.Interrupt() }()
}
func CreateServerSocket(t *testing.T, addr string) *TServerSocket {
socket, err := NewTServerSocket(addr)
if err != nil {
t.Fatalf("Failed to create server socket: %s", err)
}
return socket
}

Some files were not shown because too many files have changed in this diff Show more