diff --git a/.gitignore b/.gitignore index daf913b1..0c9e3298 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ # Folders _obj _test +.idea # Architecture specific extensions/prefixes *.[568vq] diff --git a/.travis.yml b/.travis.yml index 15e5a192..eaaabc55 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,14 +1,11 @@ +arch: + - amd64 + language: go -go_import_path: github.com/pkg/errors +go_import_path: github.com/pingcap/errors go: - - 1.4.x - - 1.5.x - - 1.6.x - - 1.7.x - - 1.8.x - - 1.9.x - - 1.10.x - - tip + - 1.13.x + - stable script: - go test -v ./... diff --git a/README.md b/README.md index 6483ba2a..b97656ac 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Package errors provides simple error handling primitives. -`go get github.com/pkg/errors` +`go get github.com/pingcap/errors` The traditional error handling idiom in Go is roughly akin to ```go diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index a932eade..00000000 --- a/appveyor.yml +++ /dev/null @@ -1,32 +0,0 @@ -version: build-{build}.{branch} - -clone_folder: C:\gopath\src\github.com\pkg\errors -shallow_clone: true # for startup speed - -environment: - GOPATH: C:\gopath - -platform: - - x64 - -# http://www.appveyor.com/docs/installed-software -install: - # some helpful output for debugging builds - - go version - - go env - # pre-installed MinGW at C:\MinGW is 32bit only - # but MSYS2 at C:\msys64 has mingw64 - - set PATH=C:\msys64\mingw64\bin;%PATH% - - gcc --version - - g++ --version - -build_script: - - go install -v ./... - -test_script: - - set PATH=C:\gopath\bin;%PATH% - - go test -v ./... - -#artifacts: -# - path: '%GOPATH%\bin\*.exe' -deploy: off diff --git a/bench_test.go b/bench_test.go index 903b5f2d..b8ad7a02 100644 --- a/bench_test.go +++ b/bench_test.go @@ -1,12 +1,10 @@ -// +build go1.7 - package errors import ( + stderrors "errors" "fmt" + "strings" "testing" - - stderrors "errors" ) func noErrors(at, depth int) error { @@ -25,7 +23,7 @@ func yesErrors(at, depth int) error { // GlobalE is an exported global to store the result of benchmark results, // preventing the compiler from optimising the benchmark functions away. -var GlobalE error +var GlobalE any func BenchmarkErrors(b *testing.B) { type run struct { @@ -61,3 +59,141 @@ func BenchmarkErrors(b *testing.B) { }) } } + +func BenchmarkStackFormatting(b *testing.B) { + type run struct { + stack int + format string + } + runs := []run{ + {10, "%s"}, + {10, "%v"}, + {10, "%+v"}, + {30, "%s"}, + {30, "%v"}, + {30, "%+v"}, + {60, "%s"}, + {60, "%v"}, + {60, "%+v"}, + } + + var stackStr string + for _, r := range runs { + name := fmt.Sprintf("%s-stack-%d", r.format, r.stack) + b.Run(name, func(b *testing.B) { + err := yesErrors(0, r.stack) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stackStr = fmt.Sprintf(r.format, err) + } + b.StopTimer() + }) + } + + for _, r := range runs { + name := fmt.Sprintf("%s-stacktrace-%d", r.format, r.stack) + b.Run(name, func(b *testing.B) { + err := yesErrors(0, r.stack) + st := err.(*fundamental).stack.StackTrace() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stackStr = fmt.Sprintf(r.format, st) + } + b.StopTimer() + }) + } + GlobalE = stackStr +} + +type argsProfile struct { + name string + containsString bool + build func(count, stringLen int) []any +} + +type benchmarkHackedStr string + +func (s benchmarkHackedStr) FreezeStr() string { + return string(append([]byte(nil), s...)) +} + +func buildHackedStringArgs(count, stringLen int) []any { + arg := benchmarkHackedStr(strings.Repeat("x", stringLen)) + args := make([]any, count) + for i := range args { + args[i] = arg + } + return args +} + +func buildPlainStringArgs(count, stringLen int) []any { + arg := strings.Repeat("x", stringLen) + args := make([]any, count) + for i := range args { + args[i] = arg + } + return args +} + +func buildIntArgs(count, _ int) []any { + args := make([]any, count) + for i := range args { + args[i] = i + } + return args +} + +func BenchmarkByArgsHackedStrFreeze(b *testing.B) { + errPrototype := Normalize("bench", RFCCodeText("Internal:Bench")) + + apiCases := []struct { + name string + call func(errPrototype *Error, args []any) error + }{ + { + name: "FastGenByArgs", + call: func(errPrototype *Error, args []any) error { + return errPrototype.FastGenByArgs(args...) + }, + }, + } + profiles := []argsProfile{ + {name: "plain", containsString: true, build: buildPlainStringArgs}, + {name: "hacked", containsString: true, build: buildHackedStringArgs}, + } + argCounts := []int{1, 4, 8} + stringLens := []int{16, 1024} + + for _, apiCase := range apiCases { + b.Run(apiCase.name, func(b *testing.B) { + for _, profile := range profiles { + lens := []int{0} + if profile.containsString { + lens = stringLens + } + for _, argCount := range argCounts { + for _, strLen := range lens { + templateArgs := profile.build(argCount, strLen) + caseName := fmt.Sprintf("type-%s/count-%d", profile.name, argCount) + if profile.containsString { + caseName = fmt.Sprintf("%s/strlen-%d", caseName, strLen) + } + + b.Run(caseName, func(b *testing.B) { + var err error + args := make([]any, len(templateArgs)) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + copy(args, templateArgs) + err = apiCase.call(errPrototype, args) + } + GlobalE = err + }) + } + } + } + }) + } +} diff --git a/compatible_shim.go b/compatible_shim.go new file mode 100644 index 00000000..0835a0cd --- /dev/null +++ b/compatible_shim.go @@ -0,0 +1,98 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package errors + +import ( + "encoding/json" + "strconv" + "strings" +) + +// class2RFCCode is used for compatible with old version of TiDB. When +// marshal Error to json, old version of TiDB contain a 'class' field +// which is represented for error class. In order to parse and convert +// json to errors.Error, using this map to convert error class to RFC +// error code text. here is reference: +// https://github.com/pingcap/parser/blob/release-3.0/terror/terror.go#L58 +var class2RFCCode = map[int]string{ + 1: "autoid", + 2: "ddl", + 3: "domain", + 4: "evaluator", + 5: "executor", + 6: "expression", + 7: "admin", + 8: "kv", + 9: "meta", + 10: "planner", + 11: "parser", + 12: "perfschema", + 13: "privilege", + 14: "schema", + 15: "server", + 16: "struct", + 17: "variable", + 18: "xeval", + 19: "table", + 20: "types", + 21: "global", + 22: "mocktikv", + 23: "json", + 24: "tikv", + 25: "session", + 26: "plugin", + 27: "util", +} +var rfcCode2class map[string]int + +func init() { + rfcCode2class = make(map[string]int) + for k, v := range class2RFCCode { + rfcCode2class[v] = k + } +} + +// MarshalJSON implements json.Marshaler interface. +// aware that this function cannot save a 'registered' status, +// since we cannot access the registry when unmarshaling, +// and the original global registry would be removed here. +// This function is reserved for compatibility. +func (e *Error) MarshalJSON() ([]byte, error) { + ec := strings.Split(string(e.codeText), ":")[0] + return json.Marshal(&jsonError{ + Class: rfcCode2class[ec], + Code: int(e.code), + Msg: e.GetMsg(), + RFCCode: string(e.codeText), + }) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +// aware that this function cannot create a 'registered' error, +// since we cannot access the registry in this context, +// and the original global registry is removed. +// This function is reserved for compatibility. +func (e *Error) UnmarshalJSON(data []byte) error { + tErr := &jsonError{} + if err := json.Unmarshal(data, &tErr); err != nil { + return Trace(err) + } + e.codeText = ErrCodeText(tErr.RFCCode) + if tErr.RFCCode == "" && tErr.Class > 0 { + e.codeText = ErrCodeText(class2RFCCode[tErr.Class] + ":" + strconv.Itoa(tErr.Code)) + } + e.code = ErrCode(tErr.Code) + e.message = tErr.Msg + return nil +} diff --git a/errdoc-gen/README.md b/errdoc-gen/README.md new file mode 100644 index 00000000..8f98afcc --- /dev/null +++ b/errdoc-gen/README.md @@ -0,0 +1,6 @@ +## Usage + +```shell script +# eg: ./errdoc-gen --source devel/pingap/tidb --module github.com/pingcap/tidb --output devel/pingap/tidb/errors.toml +./errdoc-gen --source /path/to/source/code --module ${module-name} --output /path/to/errors.toml +``` \ No newline at end of file diff --git a/errdoc-gen/main.go b/errdoc-gen/main.go new file mode 100644 index 00000000..a5e4df33 --- /dev/null +++ b/errdoc-gen/main.go @@ -0,0 +1,349 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" + "text/template" +) + +var opt struct { + source string + module string + output string + ignore string + retainCode bool +} + +func init() { + flag.StringVar(&opt.source, "source", "", "The source directory of error documentation") + flag.StringVar(&opt.module, "module", "", "The module name of target repository") + flag.StringVar(&opt.output, "output", "", "The output path of error documentation file") + flag.StringVar(&opt.ignore, "ignore", "", "Directories to ignore, splitted by comma") + flag.BoolVar(&opt.retainCode, "retain-code", false, "Retain the generated code when generator exit") +} + +func log(format string, args ...any) { + fmt.Println(fmt.Sprintf(format, args...)) +} + +func fatal(format string, args ...any) { + log(format, args...) + os.Exit(1) +} + +const autoDirectoryName = "_errdoc-generator" +const entryFileName = "main.go" + +func main() { + flag.Parse() + if opt.source == "" { + fatal("The source directory cannot be empty") + } + + source, err := filepath.EvalSymlinks(opt.source) + if err != nil { + fatal("Evaluate symbol link path %s failed: %v", opt.source, err) + } + opt.source = source + + _, err = os.Stat(filepath.Join(opt.source, "go.mod")) + if os.IsNotExist(err) { + fatal("The source directory is not the root path of codebase(go.mod not found)") + } + + if opt.output == "" { + opt.output = filepath.Join(opt.source, "errors.toml") + log("The --output argument is missing, default to %s", opt.output) + } + + errNames, err := errdoc(opt.source, opt.module) + if err != nil { + log("Extract the error documentation failed: %+v", err) + } + + targetDir := filepath.Join(opt.source, autoDirectoryName) + if err := os.MkdirAll(targetDir, 0755); err != nil { + fatal("Cannot create the errdoc: %+v", err) + } + + if !opt.retainCode { + defer os.RemoveAll(targetDir) + } + + tmpl := ` +package main + +import ( + "bytes" + "flag" + "io/ioutil" + "os" + "reflect" + "fmt" + "sort" + "strings" + + "github.com/BurntSushi/toml" + "github.com/pingcap/errors" +{{- range $decl := .}} + {{$decl.PackageName}} "{{- $decl.ImportPath}}" +{{- end}} +) + +func main() { + var outpath string + flag.StringVar(&outpath, "output", "", "Specify the error documentation output file path") + flag.Parse() + if outpath == "" { + println("Usage: ./_errdoc-generator --output /path/to/errors.toml") + os.Exit(1) + } + + // Read-in the exists file and merge the description/workaround from exists file + existDefinition := map[string]spec{} + if file, err := ioutil.ReadFile(outpath); err == nil { + err = toml.Unmarshal(file, &existDefinition) + if err != nil { + println(fmt.Sprintf("Invalid toml file %s when merging exists description/workaround: %v", outpath, err)) + os.Exit(1) + } + } + + var allErrors []error + {{- range $decl := .}} + {{- range $err := $decl.ErrNames}} + allErrors = append(allErrors, {{$decl.PackageName}}.{{- $err}}) + {{- end}} + {{- end}} + + var dedup = map[string]spec{} + for _, e := range allErrors { + terr, ok := e.(*errors.Error) + if !ok { + println("Non-normalized error:", e.Error()) + } else { + val := reflect.ValueOf(terr).Elem() + codeText := val.FieldByName("codeText") + message := val.FieldByName("message") + if previous, found := dedup[codeText.String()]; found { + println("Duplicated error code:", codeText.String()) + if message.String() < previous.Error { + continue + } + } + s := spec{ + Code: codeText.String(), + Error: message.String(), + } + if exist, found := existDefinition[s.Code]; found { + s.Description = strings.TrimSpace(exist.Description) + s.Workaround = strings.TrimSpace(exist.Workaround) + } + dedup[codeText.String()] = s + } + } + + var sorted []spec + for _, item := range dedup { + sorted = append(sorted, item) + } + sort.Slice(sorted, func(i, j int) bool { + // TiDB exits duplicated code + if sorted[i].Code == sorted[j].Code { + return sorted[i].Error < sorted[j].Error + } + return sorted[i].Code < sorted[j].Code + }) + + // We don't use toml library to serialize it due to cannot reserve the order for map[string]spec + buffer := bytes.NewBufferString("# AUTOGENERATED BY github.com/pingcap/errors/errdoc-gen\n" + + "# YOU CAN CHANGE THE 'description'/'workaround' FIELDS IF THEM ARE IMPROPER.\n\n") + for _, item := range sorted { + buffer.WriteString(fmt.Sprintf("[\"%s\"]\nerror = '''\n%s\n'''\n", item.Code, item.Error)) + if item.Description != "" { + buffer.WriteString(fmt.Sprintf("description = '''\n%s\n'''\n", item.Description)) + } + if item.Workaround != "" { + buffer.WriteString(fmt.Sprintf("workaround = '''\n%s\n'''\n", item.Workaround)) + } + buffer.WriteString("\n") + } + if err := ioutil.WriteFile(outpath, buffer.Bytes(), 0644); err != nil { + panic(err) + } +} +` + "type spec struct {\n" + + "Code string\n" + + "Error string `toml:\"error\"`\n" + + "Description string `toml:\"description\"`\n" + + "Workaround string `toml:\"workaround\"`\n" + + "}" + + t, err := template.New("_errdoc-template").Parse(tmpl) + if err != nil { + fatal("Parse template failed: %+v", err) + } + + outFile := filepath.Join(targetDir, entryFileName) + out, err := os.OpenFile(outFile, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + fatal("Open %s failed: %+v", outFile, err) + } + defer out.Close() + + if err := t.Execute(out, errNames); err != nil { + fatal("Render template failed: %+v", err) + } + + output, err := filepath.Abs(opt.output) + if err != nil { + fatal("Evaluate the absolute path of output failed: %+v", err) + } + + cmd := exec.Command("go", "run", filepath.Join(autoDirectoryName, entryFileName), "--output", output) + cmd.Dir = opt.source + data, err := cmd.CombinedOutput() + if err != nil { + fatal("Generate %s failed: %v, output:\n%s", opt.output, err, string(data)) + } + + log("Generate successfully, output:\n%s", string(data)) +} + +type errDecl struct { + ImportPath string + PackageName string + ErrNames []string +} + +func errdoc(source, module string) ([]*errDecl, error) { + source, err := filepath.Abs(source) + if err != nil { + return nil, err + } + + dedup := map[string]*errDecl{} + + ignored := strings.Split(opt.ignore, ",") + for i := range ignored { + ignored[i] = filepath.Join(source, ignored[i]) + } + err = filepath.Walk(source, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + if slices.Contains(ignored, path) { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(path, ".go") { + return nil + } + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + if err != nil { + // Ignore invalid source file + return nil + } + errNames := export(file) + if len(errNames) < 1 { + return nil + } + dirPath := filepath.Dir(path) + subPath, err := filepath.Rel(source, dirPath) + if err != nil { + return err + } + packageName := strings.ReplaceAll(subPath, "/", "_") + if decl, found := dedup[packageName]; found { + decl.ErrNames = append(decl.ErrNames, errNames...) + } else { + decl := &errDecl{ + ImportPath: filepath.Join(module, subPath), + PackageName: packageName, + ErrNames: errNames, + } + dedup[packageName] = decl + } + return nil + }) + + var errDecls []*errDecl + for _, decl := range dedup { + errDecls = append(errDecls, decl) + } + + return errDecls, err +} + +func export(f *ast.File) []string { + if len(f.Decls) == 0 { + return nil + } + + var errNames []string + for _, decl := range f.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || len(gen.Specs) == 0 { + continue + } + for _, spec := range gen.Specs { + switch errSpec := spec.(type) { + case *ast.ValueSpec: + // CASES: + // var ErrXXX = errors.Normalize(string, opts...) + // var ( + // ErrYYY = errors.Normalize(string, opts...) + // ErrZZZ = errors.Normalize(string, opts...) + // A = errors.Normalize(string, opts...) + // ) + // var ErrXXX, ErrYYY = errors.Normalize(string, opts...), errors.Normalize(string, opts...) + // var ( + // ErrYYY = errors.Normalize(string, opts...) + // ErrZZZ, ErrWWW = errors.Normalize(string, opts...), errors.Normalize(string, opts...) + // A = errors.Normalize(string, opts...) + // ) + // + if len(errSpec.Names) != len(errSpec.Values) { + continue + } + for i, name := range errSpec.Names { + if !strings.HasPrefix(name.Name, "Err") { + continue + } + _, ok := errSpec.Values[i].(*ast.CallExpr) + if !ok { + continue + } + errNames = append(errNames, name.Name) + } + default: + continue + } + } + } + return errNames +} diff --git a/errors.go b/errors.go index 842ee804..2ac469b4 100644 --- a/errors.go +++ b/errors.go @@ -2,93 +2,88 @@ // // The traditional error handling idiom in Go is roughly akin to // -// if err != nil { -// return err -// } +// if err != nil { +// return err +// } // // which applied recursively up the call stack results in error reports // without context or debugging information. The errors package allows // programmers to add context to the failure path in their code in a way // that does not destroy the original value of the error. // -// Adding context to an error +// # Adding context to an error // -// The errors.Wrap function returns a new error that adds context to the -// original error by recording a stack trace at the point Wrap is called, +// The errors.Annotate function returns a new error that adds context to the +// original error by recording a stack trace at the point Annotate is called, // and the supplied message. For example // -// _, err := ioutil.ReadAll(r) -// if err != nil { -// return errors.Wrap(err, "read failed") -// } +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Annotate(err, "read failed") +// } // -// If additional control is required the errors.WithStack and errors.WithMessage -// functions destructure errors.Wrap into its component operations of annotating +// If additional control is required the errors.AddStack and errors.WithMessage +// functions destructure errors.Annotate into its component operations of annotating // an error with a stack trace and an a message, respectively. // -// Retrieving the cause of an error +// # Retrieving the cause of an error // -// Using errors.Wrap constructs a stack of errors, adding context to the +// Using errors.Annotate constructs a stack of errors, adding context to the // preceding error. Depending on the nature of the error it may be necessary -// to reverse the operation of errors.Wrap to retrieve the original error +// to reverse the operation of errors.Annotate to retrieve the original error // for inspection. Any error value which implements this interface // -// type causer interface { -// Cause() error -// } +// type causer interface { +// Cause() error +// } // // can be inspected by errors.Cause. errors.Cause will recursively retrieve // the topmost error which does not implement causer, which is assumed to be // the original cause. For example: // -// switch err := errors.Cause(err).(type) { -// case *MyError: -// // handle specifically -// default: -// // unknown error -// } +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } // // causer interface is not exported by this package, but is considered a part // of stable public API. +// errors.Unwrap is also available: this will retrieve the next error in the chain. // -// Formatted printing of errors +// # Formatted printing of errors // // All error values returned from this package implement fmt.Formatter and can // be formatted by the fmt package. The following verbs are supported // -// %s print the error. If the error has a Cause it will be -// printed recursively -// %v see %s -// %+v extended format. Each Frame of the error's StackTrace will -// be printed in detail. +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. // -// Retrieving the stack trace of an error or wrapper +// # Retrieving the stack trace of an error or wrapper // -// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are -// invoked. This information can be retrieved with the following interface. +// New, Errorf, Annotate, and Annotatef record a stack trace at the point they are invoked. +// This information can be retrieved with the StackTracer interface that returns +// a StackTrace. Where errors.StackTrace is defined as // -// type stackTracer interface { -// StackTrace() errors.StackTrace -// } -// -// Where errors.StackTrace is defined as -// -// type StackTrace []Frame +// type StackTrace []Frame // // The Frame type represents a call site in the stack trace. Frame supports // the fmt.Formatter interface that can be used for printing information about // the stack trace of this error. For example: // -// if err, ok := err.(stackTracer); ok { -// for _, f := range err.StackTrace() { -// fmt.Printf("%+s:%d", f) -// } -// } -// -// stackTracer interface is not exported by this package, but is considered a part -// of stable public API. +// if stacked := errors.GetStackTracer(err); stacked != nil { +// for _, f := range stacked.StackTrace() { +// fmt.Printf("%+s:%d\n", f, f) +// } +// } // // See the documentation for Frame.Format for more details. +// +// errors.Find can be used to search for an error in the error chain. package errors import ( @@ -96,6 +91,12 @@ import ( "io" ) +// represent an error carries with message +type messenger interface { + // GetSelfMsg get its own message, the message of its cause error is NOT included. + GetSelfMsg() string +} + // New returns an error with the supplied message. // New also records the stack trace at the point it was called. func New(message string) error { @@ -108,21 +109,44 @@ func New(message string) error { // Errorf formats according to a format specifier and returns the string // as a value that satisfies error. // Errorf also records the stack trace at the point it was called. -func Errorf(format string, args ...interface{}) error { +func Errorf(format string, args ...any) error { return &fundamental{ msg: fmt.Sprintf(format, args...), stack: callers(), } } +// StackTraceAware is an optimization to avoid repetitive traversals of an error chain. +// HasStack checks for this marker first. +// Annotate/Wrap and Annotatef/Wrapf will produce this marker. +type StackTraceAware interface { + HasStack() bool +} + +// HasStack tells whether a StackTracer exists in the error chain +func HasStack(err error) bool { + if errWithStack, ok := err.(StackTraceAware); ok { + return errWithStack.HasStack() + } + // Error.FastGenXXX or call SuspendStack directly will make an empty stack trace, + // which should be considered as no stack trace, to allow upper layer code to + // add stack trace with Trace. + stackTracer := GetStackTracer(err) + return stackTracer != nil && !stackTracer.Empty() +} + // fundamental is an error that has a message and a stack, but no caller. type fundamental struct { msg string *stack } +var _ messenger = (*fundamental)(nil) + func (f *fundamental) Error() string { return f.msg } +func (f *fundamental) GetSelfMsg() string { return f.msg } + func (f *fundamental) Format(s fmt.State, verb rune) { switch verb { case 'v': @@ -141,10 +165,27 @@ func (f *fundamental) Format(s fmt.State, verb rune) { // WithStack annotates err with a stack trace at the point WithStack was called. // If err is nil, WithStack returns nil. +// +// For most use cases this is deprecated and AddStack should be used (which will ensure just one stack trace). +// However, one may want to use this in some situations, for example to create a 2nd trace across a goroutine. func WithStack(err error) error { if err == nil { return nil } + + return &withStack{ + err, + callers(), + } +} + +// AddStack is similar to WithStack. +// However, it will first check with HasStack to see if a stack trace already exists in the causer chain before creating another one. +func AddStack(err error) error { + if err == nil || HasStack(err) { + return err + } + return &withStack{ err, callers(), @@ -156,8 +197,19 @@ type withStack struct { *stack } +var _ messenger = (*withStack)(nil) + func (w *withStack) Cause() error { return w.error } +func (w *withStack) GetSelfMsg() string { + // it doesn't have its own message, but we still need impl it to avoid calling + // err.Error() for its cause + return "" +} + +// Unwrap provides compatibility for Go 1.13 error chains. +func (w *withStack) Unwrap() error { return w.error } + func (w *withStack) Format(s fmt.State, verb rune) { switch verb { case 'v': @@ -177,13 +229,18 @@ func (w *withStack) Format(s fmt.State, verb rune) { // Wrap returns an error annotating err with a stack trace // at the point Wrap is called, and the supplied message. // If err is nil, Wrap returns nil. +// +// For most use cases this is deprecated in favor of Annotate. +// Annotate avoids creating duplicate stack traces. func Wrap(err error, message string) error { if err == nil { return nil } + hasStack := HasStack(err) err = &withMessage{ - cause: err, - msg: message, + cause: err, + msg: message, + causeHasStack: hasStack, } return &withStack{ err, @@ -194,13 +251,18 @@ func Wrap(err error, message string) error { // Wrapf returns an error annotating err with a stack trace // at the point Wrapf is call, and the format specifier. // If err is nil, Wrapf returns nil. -func Wrapf(err error, format string, args ...interface{}) error { +// +// For most use cases this is deprecated in favor of Annotatef. +// Annotatef avoids creating duplicate stack traces. +func Wrapf(err error, format string, args ...any) error { if err == nil { return nil } + hasStack := HasStack(err) err = &withMessage{ - cause: err, - msg: fmt.Sprintf(format, args...), + cause: err, + msg: fmt.Sprintf(format, args...), + causeHasStack: hasStack, } return &withStack{ err, @@ -215,19 +277,29 @@ func WithMessage(err error, message string) error { return nil } return &withMessage{ - cause: err, - msg: message, + cause: err, + msg: message, + causeHasStack: HasStack(err), } } type withMessage struct { - cause error - msg string + cause error + msg string + causeHasStack bool } +var _ messenger = (*withMessage)(nil) + func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } func (w *withMessage) Cause() error { return w.cause } +func (w *withMessage) GetSelfMsg() string { return w.msg } + +// Unwrap provides compatibility for Go 1.13 error chains. +func (w *withMessage) Unwrap() error { return w.cause } +func (w *withMessage) HasStack() bool { return w.causeHasStack } + func (w *withMessage) Format(s fmt.State, verb rune) { switch verb { case 'v': @@ -237,8 +309,10 @@ func (w *withMessage) Format(s fmt.State, verb rune) { return } fallthrough - case 's', 'q': + case 's': io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) } } @@ -246,24 +320,73 @@ func (w *withMessage) Format(s fmt.State, verb rune) { // An error value has a cause if it implements the following // interface: // -// type causer interface { -// Cause() error -// } +// type causer interface { +// Cause() error +// } // // If the error does not implement Cause, the original error will // be returned. If the error is nil, nil will be returned without further // investigation. func Cause(err error) error { + cause := Unwrap(err) + if cause == nil { + return err + } + return Cause(cause) +} + +// Unwrap uses causer to return the next error in the chain or nil. +// This goes one-level deeper, whereas Cause goes as far as possible +func Unwrap(err error) error { type causer interface { Cause() error } + if unErr, ok := err.(causer); ok { + return unErr.Cause() + } + return nil +} + +// Find an error in the chain that matches a test function. +// returns nil if no error is found. +func Find(origErr error, test func(error) bool) error { + var foundErr error + WalkDeep(origErr, func(err error) bool { + if test(err) { + foundErr = err + return true + } + return false + }) + return foundErr +} - for err != nil { - cause, ok := err.(causer) - if !ok { - break +// GetErrStackMsg get the concat error message the whole error stack. +// it's different from err.Error(), as pingcap/errors.Error will prepend the error +// code in the result of err.Error(), like below: +// +// [types:1292]Truncated incorrect +// +// and when there are multiple errors.Error in the chain, the err.Error() will +// return like this: +// +// [Lightning:Restore:ErrEncodeKV]encode kv error ... : [types:1292]Truncated incorrect DOUBLE value: 'a'" +// +// But sometimes we only want a single error code with pure message part. +func GetErrStackMsg(err error) string { + if err == nil { + return "" + } + m, ok := err.(messenger) + if ok { + msg := m.GetSelfMsg() + causeMsg := GetErrStackMsg(Unwrap(err)) + if msg == "" { + msg = causeMsg + } else if causeMsg != "" { + msg = msg + ": " + causeMsg } - err = cause.Cause() + return msg } - return err + return err.Error() } diff --git a/errors_test.go b/errors_test.go index c4e6eef6..7a26dc8f 100644 --- a/errors_test.go +++ b/errors_test.go @@ -4,8 +4,12 @@ import ( "errors" "fmt" "io" + "net/url" "reflect" + "strconv" "testing" + + "github.com/stretchr/testify/require" ) func TestNew(t *testing.T) { @@ -28,7 +32,7 @@ func TestNew(t *testing.T) { } func TestWrapNil(t *testing.T) { - got := Wrap(nil, "no error") + got := Annotate(nil, "no error") if got != nil { t.Errorf("Wrap(nil, \"no error\"): got %#v, expected nil", got) } @@ -41,11 +45,11 @@ func TestWrap(t *testing.T) { want string }{ {io.EOF, "read error", "read error: EOF"}, - {Wrap(io.EOF, "read error"), "client error", "client error: read error: EOF"}, + {Annotate(io.EOF, "read error"), "client error", "client error: read error: EOF"}, } for _, tt := range tests { - got := Wrap(tt.err, tt.message).Error() + got := Annotate(tt.err, tt.message).Error() if got != tt.want { t.Errorf("Wrap(%v, %q): got: %v, want %v", tt.err, tt.message, got, tt.want) } @@ -79,7 +83,7 @@ func TestCause(t *testing.T) { want: io.EOF, }, { // caused error returns cause - err: Wrap(io.EOF, "ignored"), + err: Annotate(io.EOF, "ignored"), want: io.EOF, }, { err: x, // return from errors.New @@ -96,6 +100,12 @@ func TestCause(t *testing.T) { }, { WithStack(io.EOF), io.EOF, + }, { + AddStack(nil), + nil, + }, { + AddStack(io.EOF), + io.EOF, }} for i, tt := range tests { @@ -107,7 +117,7 @@ func TestCause(t *testing.T) { } func TestWrapfNil(t *testing.T) { - got := Wrapf(nil, "no error") + got := Annotate(nil, "no error") if got != nil { t.Errorf("Wrapf(nil, \"no error\"): got %#v, expected nil", got) } @@ -120,12 +130,12 @@ func TestWrapf(t *testing.T) { want string }{ {io.EOF, "read error", "read error: EOF"}, - {Wrapf(io.EOF, "read error without format specifiers"), "client error", "client error: read error without format specifiers: EOF"}, - {Wrapf(io.EOF, "read error with %d format specifier", 1), "client error", "client error: read error with 1 format specifier: EOF"}, + {Annotatef(io.EOF, "read error without format specifiers"), "client error", "client error: read error without format specifiers: EOF"}, + {Annotatef(io.EOF, "read error with %d format specifier", 1), "client error", "client error: read error with 1 format specifier: EOF"}, } for _, tt := range tests { - got := Wrapf(tt.err, tt.message).Error() + got := Annotatef(tt.err, "%s", tt.message).Error() if got != tt.want { t.Errorf("Wrapf(%v, %q): got: %v, want %v", tt.err, tt.message, got, tt.want) } @@ -154,6 +164,10 @@ func TestWithStackNil(t *testing.T) { if got != nil { t.Errorf("WithStack(nil): got %#v, expected nil", got) } + got = AddStack(nil) + if got != nil { + t.Errorf("AddStack(nil): got %#v, expected nil", got) + } } func TestWithStack(t *testing.T) { @@ -173,6 +187,50 @@ func TestWithStack(t *testing.T) { } } +func TestAddStack(t *testing.T) { + tests := []struct { + err error + want string + }{ + {io.EOF, "EOF"}, + {AddStack(io.EOF), "EOF"}, + } + + for _, tt := range tests { + got := AddStack(tt.err).Error() + if got != tt.want { + t.Errorf("AddStack(%v): got: %v, want %v", tt.err, got, tt.want) + } + } +} + +func TestGetStackTracer(t *testing.T) { + orig := io.EOF + if GetStackTracer(orig) != nil { + t.Errorf("GetStackTracer: got: %v, want %v", GetStackTracer(orig), nil) + } + stacked := AddStack(orig) + if GetStackTracer(stacked).(error) != stacked { + t.Errorf("GetStackTracer(stacked): got: %v, want %v", GetStackTracer(stacked), stacked) + } + final := AddStack(stacked) + if GetStackTracer(final).(error) != stacked { + t.Errorf("GetStackTracer(final): got: %v, want %v", GetStackTracer(final), stacked) + } +} + +func TestAddStackDedup(t *testing.T) { + stacked := WithStack(io.EOF) + err := AddStack(stacked) + if err != stacked { + t.Errorf("AddStack: got: %+v, want %+v", err, stacked) + } + err = WithStack(stacked) + if err == stacked { + t.Errorf("WithStack: got: %v, don't want %v", err, stacked) + } +} + func TestWithMessageNil(t *testing.T) { got := WithMessage(nil, "no error") if got != nil { @@ -209,12 +267,14 @@ func TestErrorEquality(t *testing.T) { errors.New("EOF"), New("EOF"), Errorf("EOF"), - Wrap(io.EOF, "EOF"), - Wrapf(io.EOF, "EOF%d", 2), + Annotate(io.EOF, "EOF"), + Annotatef(io.EOF, "EOF%d", 2), WithMessage(nil, "whoops"), WithMessage(io.EOF, "whoops"), WithStack(io.EOF), WithStack(nil), + AddStack(io.EOF), + AddStack(nil), } for i := range vals { @@ -223,3 +283,242 @@ func TestErrorEquality(t *testing.T) { } } } + +func TestFind(t *testing.T) { + eNew := errors.New("error") + wrapped := Annotate(nilError{}, "nil") + tests := []struct { + err error + finder func(error) bool + found error + }{ + {io.EOF, func(_ error) bool { return true }, io.EOF}, + {io.EOF, func(_ error) bool { return false }, nil}, + {io.EOF, func(err error) bool { return err == io.EOF }, io.EOF}, + {io.EOF, func(err error) bool { return err != io.EOF }, nil}, + + {eNew, func(err error) bool { return true }, eNew}, + {eNew, func(err error) bool { return false }, nil}, + + {nilError{}, func(err error) bool { return true }, nilError{}}, + {nilError{}, func(err error) bool { return false }, nil}, + {nilError{}, func(err error) bool { _, ok := err.(nilError); return ok }, nilError{}}, + + {wrapped, func(err error) bool { return true }, wrapped}, + {wrapped, func(err error) bool { return false }, nil}, + {wrapped, func(err error) bool { _, ok := err.(nilError); return ok }, nilError{}}, + } + + for _, tt := range tests { + got := Find(tt.err, tt.finder) + if got != tt.found { + t.Errorf("WithMessage(%v): got: %q, want %q", tt.err, got, tt.found) + } + } +} + +type errWalkTest struct { + cause error + sub []error + v int +} + +func (e *errWalkTest) Error() string { + return strconv.Itoa(e.v) +} + +func (e *errWalkTest) Cause() error { + return e.cause +} + +func (e *errWalkTest) Errors() []error { + return e.sub +} + +func testFind(err error, v int) bool { + return WalkDeep(err, func(err error) bool { + e := err.(*errWalkTest) + return e.v == v + }) +} + +func TestWalkDeep(t *testing.T) { + err := &errWalkTest{ + sub: []error{ + &errWalkTest{ + v: 10, + cause: &errWalkTest{v: 11}, + }, + &errWalkTest{ + v: 20, + cause: &errWalkTest{v: 21, cause: &errWalkTest{v: 22}}, + }, + &errWalkTest{ + v: 30, + cause: &errWalkTest{v: 31}, + }, + }, + } + + if !testFind(err, 11) { + t.Errorf("not found in first cause chain") + } + + if !testFind(err, 22) { + t.Errorf("not found in siblings") + } + + if testFind(err, 32) { + t.Errorf("found not exists") + } +} + +func TestWalkDeepNil(t *testing.T) { + require.False(t, WalkDeep(nil, func(err error) bool { return true })) +} + +func TestWalkDeepComplexTree(t *testing.T) { + err := &errWalkTest{v: 1, cause: &errWalkTest{ + sub: []error{ + &errWalkTest{ + v: 10, + cause: &errWalkTest{v: 11}, + }, + &errWalkTest{ + v: 20, + sub: []error{ + &errWalkTest{v: 21}, + &errWalkTest{v: 22}, + }, + }, + &errWalkTest{ + v: 30, + cause: &errWalkTest{v: 31}, + }, + }, + }} + + assertFind := func(v int, comment string) { + if !testFind(err, v) { + t.Errorf("%d not found in the error: %s", v, comment) + } + } + assertNotFind := func(v int, comment string) { + if testFind(err, v) { + t.Errorf("%d found in the error, but not expected: %s", v, comment) + } + } + + assertFind(1, "shallow search") + assertFind(11, "deep search A1") + assertFind(21, "deep search A2") + assertFind(22, "deep search B1") + assertNotFind(23, "deep search Neg") + assertFind(31, "deep search B2") + assertNotFind(32, "deep search Neg") + assertFind(30, "Tree node A") + assertFind(20, "Tree node with many children") +} + +type fooError int + +func (fooError) Error() string { + return "foo" +} + +func TestWorkWithStdErrors(t *testing.T) { + e1 := fooError(100) + e2 := Normalize("e2", RFCCodeText("e2")) + e3 := Normalize("e3", RFCCodeText("e3")) + e21 := e2.Wrap(e1) + e31 := e3.Wrap(e1) + e32 := e3.Wrap(e2) + e321 := e3.Wrap(e21) + + unwrapTbl := []struct { + x *Error // x.Unwrap() == y + y error + }{{e2, nil}, {e3, nil}, {e21, e1}, {e31, e1}, {e32, e2}, {e321, e21}} + for _, c := range unwrapTbl { + if c.x.Unwrap() != c.y { + t.Errorf("`%s`.Unwrap() != `%s`", c.x, c.y) + } + } + + isTbl := []struct { + x, y error // errors.Is(x, y) == b + b bool + }{ + {e1, e1, true}, {e2, e1, false}, {e3, e1, false}, {e21, e1, true}, {e321, e1, true}, + {e1, e2, false}, {e2, e2, true}, {e3, e2, false}, {e21, e2, true}, {e31, e2, false}, {e321, e2, true}, + {e2, e21, true}, {e21, e21, true}, {e31, e21, false}, {e321, e21, true}, + {e321, e321, true}, {e3, e321, true}, {e21, e321, false}, + } + for _, c := range isTbl { + if c.b && !errors.Is(c.x, c.y) { + t.Errorf("`%s` is not `%s`", c.x, c.y) + } + if !c.b && errors.Is(c.x, c.y) { + t.Errorf("`%s` is `%s`", c.x, c.y) + } + } + + var e1x fooError + if ok := errors.As(e21, &e1x); !ok { + t.Error("e21 cannot convert to e1") + } + if int(e1x) != 100 { + t.Error("e1x is not 100") + } + + var e2x *Error + if ok := errors.As(e21, &e2x); !ok { + t.Error("e21 cannot convert to e2") + } + if e2x.ID() != "e2" { + t.Error("err is not e2") + } + + e3x := e3.Wrap(e1) + if ok := errors.As(e21, &e3x); !ok { + t.Error("e21 cannot convert to e3") + } + if e3x.ID() != "e2" { + t.Error("err is not e2") + } +} + +func TestHasTrace(t *testing.T) { + targetErr := Normalize("test err") + require.False(t, HasStack(targetErr)) + require.False(t, HasStack(targetErr.FastGen("fast gen"))) + require.False(t, HasStack(targetErr.FastGenByArgs("fast gen arg"))) + require.True(t, HasStack(Trace(targetErr.FastGen("fast gen")))) + require.True(t, HasStack(targetErr.GenWithStack("gen"))) +} + +func TestGetErrStackMsg(t *testing.T) { + require.Equal(t, "", GetErrStackMsg(nil)) + + namedErr := Normalize("named err message", RFCCodeText("NamedError")) + require.False(t, HasStack(namedErr)) + require.Equal(t, "named err message", GetErrStackMsg(namedErr)) + tracedErr := Trace(namedErr) + require.Equal(t, "named err message", GetErrStackMsg(tracedErr)) + + annotatedErr := Annotate(tracedErr, "annotated message") + require.Equal(t, "annotated message: named err message", GetErrStackMsg(annotatedErr)) + + annotatedErr = Annotate(annotatedErr, "annotated message 2") + require.Equal(t, "annotated message 2: annotated message: named err message", GetErrStackMsg(annotatedErr)) + + fundErr := New("new fundamental error") + wrappedErr := namedErr.Wrap(fundErr) + require.Equal(t, "named err message: new fundamental error", GetErrStackMsg(wrappedErr)) + fastGen := wrappedErr.FastGen("fast gen") + require.Equal(t, "fast gen: new fundamental error", GetErrStackMsg(fastGen)) + + urlErr := &url.Error{Op: "GET", URL: "/url", Err: errors.New("internal golang err")} + fastGen = namedErr.Wrap(urlErr).FastGen("fast gen") + require.Equal(t, `fast gen: GET "/url": internal golang err`, GetErrStackMsg(fastGen)) +} diff --git a/example_test.go b/example_test.go index c1fc13e3..cc4edaa2 100644 --- a/example_test.go +++ b/example_test.go @@ -2,8 +2,7 @@ package errors_test import ( "fmt" - - "github.com/pkg/errors" + "github.com/pingcap/errors" ) func ExampleNew() { diff --git a/format_test.go b/format_test.go index c2eef5f0..10c9b155 100644 --- a/format_test.go +++ b/format_test.go @@ -26,8 +26,8 @@ func TestFormatNew(t *testing.T) { New("error"), "%+v", "error\n" + - "github.com/pkg/errors.TestFormatNew\n" + - "\t.+/github.com/pkg/errors/format_test.go:26", + "github.com/pingcap/errors.TestFormatNew\n" + + "\t.+/pingcap/errors/format_test.go:26", }, { New("error"), "%q", @@ -56,8 +56,8 @@ func TestFormatErrorf(t *testing.T) { Errorf("%s", "error"), "%+v", "error\n" + - "github.com/pkg/errors.TestFormatErrorf\n" + - "\t.+/github.com/pkg/errors/format_test.go:56", + "github.com/pingcap/errors.TestFormatErrorf\n" + + "\t.+/pingcap/errors/format_test.go:56", }} for i, tt := range tests { @@ -71,45 +71,45 @@ func TestFormatWrap(t *testing.T) { format string want string }{{ - Wrap(New("error"), "error2"), + Annotate(New("error"), "error2"), "%s", "error2: error", }, { - Wrap(New("error"), "error2"), + Annotate(New("error"), "error2"), "%v", "error2: error", }, { - Wrap(New("error"), "error2"), + Annotate(New("error"), "error2"), "%+v", "error\n" + - "github.com/pkg/errors.TestFormatWrap\n" + - "\t.+/github.com/pkg/errors/format_test.go:82", + "github.com/pingcap/errors.TestFormatWrap\n" + + "\t.+/pingcap/errors/format_test.go:82", }, { - Wrap(io.EOF, "error"), + Annotate(io.EOF, "error"), "%s", "error: EOF", }, { - Wrap(io.EOF, "error"), + Annotate(io.EOF, "error"), "%v", "error: EOF", }, { - Wrap(io.EOF, "error"), + Annotate(io.EOF, "error"), "%+v", "EOF\n" + "error\n" + - "github.com/pkg/errors.TestFormatWrap\n" + - "\t.+/github.com/pkg/errors/format_test.go:96", + "github.com/pingcap/errors.TestFormatWrap\n" + + "\t.+/pingcap/errors/format_test.go:96", }, { - Wrap(Wrap(io.EOF, "error1"), "error2"), + Annotate(Annotate(io.EOF, "error1"), "error2"), "%+v", "EOF\n" + "error1\n" + - "github.com/pkg/errors.TestFormatWrap\n" + - "\t.+/github.com/pkg/errors/format_test.go:103\n", + "github.com/pingcap/errors.TestFormatWrap\n" + + "\t.+/pingcap/errors/format_test.go:103\n", }, { - Wrap(New("error with space"), "context"), + Annotate(New("error with space"), "context"), "%q", - `"context: error with space"`, + `context: error with space`, }} for i, tt := range tests { @@ -123,34 +123,34 @@ func TestFormatWrapf(t *testing.T) { format string want string }{{ - Wrapf(io.EOF, "error%d", 2), + Annotatef(io.EOF, "error%d", 2), "%s", "error2: EOF", }, { - Wrapf(io.EOF, "error%d", 2), + Annotatef(io.EOF, "error%d", 2), "%v", "error2: EOF", }, { - Wrapf(io.EOF, "error%d", 2), + Annotatef(io.EOF, "error%d", 2), "%+v", "EOF\n" + "error2\n" + - "github.com/pkg/errors.TestFormatWrapf\n" + - "\t.+/github.com/pkg/errors/format_test.go:134", + "github.com/pingcap/errors.TestFormatWrapf\n" + + "\t.+/pingcap/errors/format_test.go:134", }, { - Wrapf(New("error"), "error%d", 2), + Annotatef(New("error"), "error%d", 2), "%s", "error2: error", }, { - Wrapf(New("error"), "error%d", 2), + Annotatef(New("error"), "error%d", 2), "%v", "error2: error", }, { - Wrapf(New("error"), "error%d", 2), + Annotatef(New("error"), "error%d", 2), "%+v", "error\n" + - "github.com/pkg/errors.TestFormatWrapf\n" + - "\t.+/github.com/pkg/errors/format_test.go:149", + "github.com/pingcap/errors.TestFormatWrapf\n" + + "\t.+/pingcap/errors/format_test.go:149", }} for i, tt := range tests { @@ -175,8 +175,8 @@ func TestFormatWithStack(t *testing.T) { WithStack(io.EOF), "%+v", []string{"EOF", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:175"}, + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:175"}, }, { WithStack(New("error")), "%s", @@ -189,37 +189,37 @@ func TestFormatWithStack(t *testing.T) { WithStack(New("error")), "%+v", []string{"error", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:189", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:189"}, + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:189", + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:189"}, }, { WithStack(WithStack(io.EOF)), "%+v", []string{"EOF", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:197", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:197"}, + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:197", + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:197"}, }, { - WithStack(WithStack(Wrapf(io.EOF, "message"))), + WithStack(WithStack(Annotatef(io.EOF, "message"))), "%+v", []string{"EOF", "message", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:205", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:205", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:205"}, + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:205", + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:205", + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:205"}, }, { WithStack(Errorf("error%d", 1)), "%+v", []string{"error1", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:216", - "github.com/pkg/errors.TestFormatWithStack\n" + - "\t.+/github.com/pkg/errors/format_test.go:216"}, + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:216", + "github.com/pingcap/errors.TestFormatWithStack\n" + + "\t.+/pingcap/errors/format_test.go:216"}, }} for i, tt := range tests { @@ -245,8 +245,8 @@ func TestFormatWithMessage(t *testing.T) { "%+v", []string{ "error", - "github.com/pkg/errors.TestFormatWithMessage\n" + - "\t.+/github.com/pkg/errors/format_test.go:244", + "github.com/pingcap/errors.TestFormatWithMessage\n" + + "\t.+/pingcap/errors/format_test.go:244", "error2"}, }, { WithMessage(io.EOF, "addition1"), @@ -269,36 +269,34 @@ func TestFormatWithMessage(t *testing.T) { "%+v", []string{"EOF", "addition1", "addition2"}, }, { - Wrap(WithMessage(io.EOF, "error1"), "error2"), + Annotate(WithMessage(io.EOF, "error1"), "error2"), "%+v", []string{"EOF", "error1", "error2", - "github.com/pkg/errors.TestFormatWithMessage\n" + - "\t.+/github.com/pkg/errors/format_test.go:272"}, + "github.com/pingcap/errors.TestFormatWithMessage\n" + + "\t.+/pingcap/errors/format_test.go:272"}, }, { WithMessage(Errorf("error%d", 1), "error2"), "%+v", []string{"error1", - "github.com/pkg/errors.TestFormatWithMessage\n" + - "\t.+/github.com/pkg/errors/format_test.go:278", + "github.com/pingcap/errors.TestFormatWithMessage\n" + + "\t.+/pingcap/errors/format_test.go:278", "error2"}, }, { WithMessage(WithStack(io.EOF), "error"), "%+v", []string{ "EOF", - "github.com/pkg/errors.TestFormatWithMessage\n" + - "\t.+/github.com/pkg/errors/format_test.go:285", + "github.com/pingcap/errors.TestFormatWithMessage\n" + + "\t.+/pingcap/errors/format_test.go:285", "error"}, }, { - WithMessage(Wrap(WithStack(io.EOF), "inside-error"), "outside-error"), + WithMessage(Annotate(WithStack(io.EOF), "inside-error"), "outside-error"), "%+v", []string{ "EOF", - "github.com/pkg/errors.TestFormatWithMessage\n" + - "\t.+/github.com/pkg/errors/format_test.go:293", + "github.com/pingcap/errors.TestFormatWithMessage\n" + + "\t.+/pingcap/errors/format_test.go:293", "inside-error", - "github.com/pkg/errors.TestFormatWithMessage\n" + - "\t.+/github.com/pkg/errors/format_test.go:293", "outside-error"}, }} @@ -307,19 +305,19 @@ func TestFormatWithMessage(t *testing.T) { } } -func TestFormatGeneric(t *testing.T) { +/*func TestFormatGeneric(t *testing.T) { starts := []struct { err error want []string }{ {New("new-error"), []string{ "new-error", - "github.com/pkg/errors.TestFormatGeneric\n" + - "\t.+/github.com/pkg/errors/format_test.go:315"}, + "github.com/pingcap/errors.TestFormatGeneric\n" + + "\t.+/github.com/pingcap/errors/format_test.go:313"}, }, {Errorf("errorf-error"), []string{ "errorf-error", - "github.com/pkg/errors.TestFormatGeneric\n" + - "\t.+/github.com/pkg/errors/format_test.go:319"}, + "github.com/pingcap/errors.TestFormatGeneric\n" + + "\t.+/github.com/pingcap/errors/format_test.go:317"}, }, {errors.New("errors-new-error"), []string{ "errors-new-error"}, }, @@ -332,22 +330,22 @@ func TestFormatGeneric(t *testing.T) { }, { func(err error) error { return WithStack(err) }, []string{ - "github.com/pkg/errors.(func·002|TestFormatGeneric.func2)\n\t" + - ".+/github.com/pkg/errors/format_test.go:333", + "github.com/pingcap/errors.(func·002|TestFormatGeneric.func2)\n\t" + + ".+/github.com/pingcap/errors/format_test.go:331", }, }, { - func(err error) error { return Wrap(err, "wrap-error") }, + func(err error) error { return Annotate(err, "wrap-error") }, []string{ "wrap-error", - "github.com/pkg/errors.(func·003|TestFormatGeneric.func3)\n\t" + - ".+/github.com/pkg/errors/format_test.go:339", + "github.com/pingcap/errors.(func·003|TestFormatGeneric.func3)\n\t" + + ".+/github.com/pingcap/errors/format_test.go:337", }, }, { - func(err error) error { return Wrapf(err, "wrapf-error%d", 1) }, + func(err error) error { return Annotatef(err, "wrapf-error%d", 1) }, []string{ "wrapf-error1", - "github.com/pkg/errors.(func·004|TestFormatGeneric.func4)\n\t" + - ".+/github.com/pkg/errors/format_test.go:346", + "github.com/pingcap/errors.(func·004|TestFormatGeneric.func4)\n\t" + + ".+/github.com/pingcap/errors/format_test.go:346", }, }, } @@ -358,9 +356,10 @@ func TestFormatGeneric(t *testing.T) { testFormatCompleteCompare(t, s, err, "%+v", want, false) testGenericRecursive(t, err, want, wrappers, 3) } -} +}*/ -func testFormatRegexp(t *testing.T, n int, arg interface{}, format, want string) { +func testFormatRegexp(t *testing.T, n int, arg any, format, want string) { + t.Helper() got := fmt.Sprintf(format, arg) gotLines := strings.SplitN(got, "\n", -1) wantLines := strings.SplitN(want, "\n", -1) @@ -384,22 +383,21 @@ func testFormatRegexp(t *testing.T, n int, arg interface{}, format, want string) var stackLineR = regexp.MustCompile(`\.`) // parseBlocks parses input into a slice, where: -// - incase entry contains a newline, its a stacktrace -// - incase entry contains no newline, its a solo line. +// - incase entry contains a newline, its a stacktrace +// - incase entry contains no newline, its a solo line. // // Detecting stack boundaries only works incase the WithStack-calls are // to be found on the same line, thats why it is optionally here. // // Example use: // -// for _, e := range blocks { -// if strings.ContainsAny(e, "\n") { -// // Match as stack -// } else { -// // Match as line -// } -// } -// +// for _, e := range blocks { +// if strings.ContainsAny(e, "\n") { +// // Match as stack +// } else { +// // Match as line +// } +// } func parseBlocks(input string, detectStackboundaries bool) ([]string, error) { var blocks []string @@ -407,7 +405,7 @@ func parseBlocks(input string, detectStackboundaries bool) ([]string, error) { wasStack := false lines := map[string]bool{} // already found lines - for _, l := range strings.Split(input, "\n") { + for l := range strings.SplitSeq(input, "\n") { isStackLine := stackLineR.MatchString(l) switch { @@ -453,7 +451,7 @@ func parseBlocks(input string, detectStackboundaries bool) ([]string, error) { return blocks, nil } -func testFormatCompleteCompare(t *testing.T, n int, arg interface{}, format string, want []string, detectStackBoundaries bool) { +func testFormatCompleteCompare(t *testing.T, n int, arg any, format string, want []string, detectStackBoundaries bool) { gotStr := fmt.Sprintf(format, arg) got, err := parseBlocks(gotStr, detectStackBoundaries) diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..8423b18f --- /dev/null +++ b/go.mod @@ -0,0 +1,14 @@ +module github.com/pingcap/errors + +go 1.25 + +require ( + github.com/stretchr/testify v1.11.1 + go.uber.org/atomic v1.11.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..8d2f3332 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/group.go b/group.go new file mode 100644 index 00000000..f07083bc --- /dev/null +++ b/group.go @@ -0,0 +1,49 @@ +package errors + +// ErrorGroup is an interface for multiple errors that are not a chain. +// This happens for example when executing multiple operations in parallel. +type ErrorGroup interface { + Errors() []error +} + +// Errors uses the ErrorGroup interface to return a slice of errors. +// If the ErrorGroup interface is not implemented it returns an array containing just the given error. +func Errors(err error) []error { + if eg, ok := err.(ErrorGroup); ok { + return eg.Errors() + } + return []error{err} +} + +// WalkDeep does a depth-first traversal of all errors. +// Any ErrorGroup is traversed (after going deep). +// The visitor function can return true to end the traversal early +// In that case, WalkDeep will return true, otherwise false. +func WalkDeep(err error, visitor func(err error) bool) bool { + if err == nil { + return false + } + + if visitor(err) { + return true + } + + // Go deep + unErr := Unwrap(err) + if unErr != nil { + if WalkDeep(unErr, visitor) { + return true + } + } + + // Go wide + if group, ok := err.(ErrorGroup); ok { + for _, err := range group.Errors() { + if early := WalkDeep(err, visitor); early { + return true + } + } + } + + return false +} diff --git a/join.go b/join.go new file mode 100644 index 00000000..af587b1e --- /dev/null +++ b/join.go @@ -0,0 +1,62 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package errors + +// Join returns an error that wraps the given errors. +// Any nil error values are discarded. +// Join returns nil if every value in errs is nil. +// The error formats as the concatenation of the strings obtained +// by calling the Error method of each element of errs, with a newline +// between each string. +// +// A non-nil error returned by Join implements the Unwrap() []error method. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + e := &joinError{ + errs: make([]error, 0, n), + } + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + return e +} + +type joinError struct { + errs []error +} + +func (e *joinError) Error() string { + var b []byte + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/join_test.go b/join_test.go new file mode 100644 index 00000000..56a44160 --- /dev/null +++ b/join_test.go @@ -0,0 +1,80 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package errors + +import ( + "reflect" + "testing" +) + +func TestJoinReturnsNil(t *testing.T) { + if err := Join(); err != nil { + t.Errorf("Join() = %v, want nil", err) + } + if err := Join(nil); err != nil { + t.Errorf("Join(nil) = %v, want nil", err) + } + if err := Join(nil, nil); err != nil { + t.Errorf("Join(nil, nil) = %v, want nil", err) + } +} + +func TestJoin(t *testing.T) { + err1 := New("err1") + err2 := New("err2") + for _, test := range []struct { + errs []error + want []error + }{{ + errs: []error{err1}, + want: []error{err1}, + }, { + errs: []error{err1, err2}, + want: []error{err1, err2}, + }, { + errs: []error{err1, nil, err2}, + want: []error{err1, err2}, + }} { + got := Join(test.errs...).(interface{ Unwrap() []error }).Unwrap() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("Join(%v) = %v; want %v", test.errs, got, test.want) + } + if len(got) != cap(got) { + t.Errorf("Join(%v) returns errors with len=%v, cap=%v; want len==cap", test.errs, len(got), cap(got)) + } + } +} + +func TestJoinErrorMethod(t *testing.T) { + err1 := New("err1") + err2 := New("err2") + for _, test := range []struct { + errs []error + want string + }{{ + errs: []error{err1}, + want: "err1", + }, { + errs: []error{err1, err2}, + want: "err1\nerr2", + }, { + errs: []error{err1, nil, err2}, + want: "err1\nerr2", + }} { + got := Join(test.errs...).Error() + if got != test.want { + t.Errorf("Join(%v).Error() = %q; want %q", test.errs, got, test.want) + } + } +} diff --git a/juju_adaptor.go b/juju_adaptor.go new file mode 100644 index 00000000..aa20182a --- /dev/null +++ b/juju_adaptor.go @@ -0,0 +1,159 @@ +package errors + +import ( + "fmt" + "strings" +) + +// ==================== juju adaptor start ======================== + +// Trace just calls AddStack. +func Trace(err error) error { + if err == nil { + return nil + } + return AddStack(err) +} + +// Annotate adds a message and ensures there is a stack trace. +func Annotate(err error, message string) error { + if err == nil { + return nil + } + hasStack := HasStack(err) + err = &withMessage{ + cause: err, + msg: message, + causeHasStack: hasStack, + } + if hasStack { + return err + } + return &withStack{ + err, + callers(), + } +} + +// Annotatef adds a message and ensures there is a stack trace. +func Annotatef(err error, format string, args ...any) error { + if err == nil { + return nil + } + hasStack := HasStack(err) + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + causeHasStack: hasStack, + } + if hasStack { + return err + } + return &withStack{ + err, + callers(), + } +} + +var emptyStack stack + +// NewNoStackError creates error without error stack +// later duplicate trace will no longer generate Stack too. +func NewNoStackError(msg string) error { + return &fundamental{ + msg: msg, + stack: &emptyStack, + } +} + +// NewNoStackErrorf creates error with error stack and formats according +// to a format specifier and returns the string as a value that satisfies error. +func NewNoStackErrorf(format string, args ...any) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: &emptyStack, + } +} + +// SuspendStack suspends stack generate for error. +// Deprecated, it's semantic is to clear the stack inside, we still allow upper +// layer to add stack again by using Trace. +// Sometimes we have very deep calling stack, the lower layer calls SuspendStack, +// but the upper layer want to add stack to it, if we disable adding stack permanently +// for an error, it's very hard to diagnose certain issues. +func SuspendStack(err error) error { + if err == nil { + return err + } + cleared := clearStack(err) + if cleared { + return err + } + return &withStack{ + err, + &emptyStack, + } +} + +func clearStack(err error) (cleared bool) { + switch typedErr := err.(type) { + case *withMessage: + return clearStack(typedErr.Cause()) + case *fundamental: + typedErr.stack = &emptyStack + return true + case *withStack: + typedErr.stack = &emptyStack + clearStack(typedErr.Cause()) + return true + default: + return false + } +} + +// ErrorStack will format a stack trace if it is available, otherwise it will be Error() +// If the error is nil, the empty string is returned +// Note that this just calls fmt.Sprintf("%+v", err) +func ErrorStack(err error) string { + if err == nil { + return "" + } + return fmt.Sprintf("%+v", err) +} + +// IsNotFound reports whether err was not found error. +func IsNotFound(err error) bool { + return strings.Contains(err.Error(), "not found") +} + +// NotFoundf represents an error with not found message. +func NotFoundf(format string, args ...any) error { + return Errorf(format+" not found", args...) +} + +// BadRequestf represents an error with bad request message. +func BadRequestf(format string, args ...any) error { + return Errorf(format+" bad request", args...) +} + +// NotSupportedf represents an error with not supported message. +func NotSupportedf(format string, args ...any) error { + return Errorf(format+" not supported", args...) +} + +// NotValidf represents an error with not valid message. +func NotValidf(format string, args ...any) error { + return Errorf(format+" not valid", args...) +} + +// IsAlreadyExists reports whether err was already exists error. +func IsAlreadyExists(err error) bool { + return strings.Contains(err.Error(), "already exists") +} + +// AlreadyExistsf represents an error with already exists message. +func AlreadyExistsf(format string, args ...any) error { + return Errorf(format+" already exists", args...) +} + +// ==================== juju adaptor end ======================== diff --git a/normalize.go b/normalize.go new file mode 100644 index 00000000..31f69426 --- /dev/null +++ b/normalize.go @@ -0,0 +1,427 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package errors + +import ( + "fmt" + "runtime" + "strconv" + + "go.uber.org/atomic" +) + +var _ fmt.Formatter = (*redactFormatter)(nil) + +// RedactLogEnabled defines whether the arguments of Error need to be redacted. +var RedactLogEnabled atomic.String + +const ( + RedactLogEnable string = "ON" + RedactLogDisable = "OFF" + RedactLogMarker = "MARKER" +) + +// ErrCode represents a specific error type in a error class. +// Same error code can be used in different error classes. +type ErrCode int + +// ErrCodeText is a textual error code that represents a specific error type in a error class. +type ErrCodeText string + +type ErrorID string +type RFCErrorCode string + +// HackedStr can provide a stable string snapshot for unsafe/ephemeral string sources. +// The method name intentionally uses FreezeStr instead of Clone so normal clone-style +// types are not accidentally treated as error-format arguments. +// During error construction, arguments implementing this interface are replaced +// with FreezeStr() so deferred error formatting does not observe later mutations. +type HackedStr interface { + FreezeStr() string +} + +// Error is the 'prototype' of a type of errors. +// Use DefineError to make a *Error: +// var ErrUnavailable = errors.Normalize("Region %d is unavailable", errors.RFCCodeText("Unavailable")) +// +// "throw" it at runtime: +// +// func Somewhat() error { +// ... +// if err != nil { +// // generate a stackful error use the message template at defining, +// // also see FastGen(it's stackless), GenWithStack(it uses custom message template). +// return ErrUnavailable.GenWithStackByArgs(region.ID) +// } +// } +// +// testing whether an error belongs to a prototype: +// +// if ErrUnavailable.Equal(err) { +// // handle this error. +// } +type Error struct { + code ErrCode + // codeText is the textual describe of the error code + codeText ErrCodeText + // message is a template of the description of this error. + // printf-style formatting is enabled. + message string + // redactArgsPos defines the positions of arguments in message that need to be redacted. + // And it is controlled by the global var RedactLogEnabled. + // For example, an original error is `Duplicate entry 'PRIMARY' for key 'key'`, + // when RedactLogEnabled is ON and redactArgsPos is [0, 1], the error is `Duplicate entry '?' for key '?'`. + // when RedactLogEnabled is MARKER and redactArgsPos is [0, 1], the error is `Duplicate entry '‹..›' for key '‹..›'`. + redactArgsPos []int + // Cause is used to warp some third party error. + cause error + args []any + file string + line int +} + +var _ messenger = (*Error)(nil) +var _ fmt.Formatter = (*Error)(nil) + +// Code returns the numeric code of this error. +// ID() will return textual error if there it is, +// when you just want to get the purely numeric error +// (e.g., for mysql protocol transmission.), this would be useful. +func (e *Error) Code() ErrCode { + return e.code +} + +// Code returns ErrorCode, by the RFC: +// +// The error code is a 3-tuple of abbreviated component name, error class and error code, +// joined by a colon like {Component}:{ErrorClass}:{InnerErrorCode}. +func (e *Error) RFCCode() RFCErrorCode { + return RFCErrorCode(e.ID()) +} + +// ID returns the ID of this error. +func (e *Error) ID() ErrorID { + if e.codeText != "" { + return ErrorID(e.codeText) + } + return ErrorID(strconv.Itoa(int(e.code))) +} + +// Location returns the location where the error is created, +// implements juju/errors locationer interface. +func (e *Error) Location() (file string, line int) { + return e.file, e.line +} + +// MessageTemplate returns the error message template of this error. +func (e *Error) MessageTemplate() string { + return e.message +} + +// Args returns the message arguments of this error. +func (e *Error) Args() []any { + return e.args +} + +// Error implements error interface. +func (e *Error) Error() string { + if e == nil { + return "" + } + if e.cause != nil { + return fmt.Sprintf("[%s]%s: %s", e.RFCCode(), e.GetMsg(), e.cause.Error()) + } + return fmt.Sprintf("[%s]%s", e.RFCCode(), e.GetMsg()) +} + +func (e *Error) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + if e != nil && e.cause != nil { + fmt.Fprintf(s, "%+v\n", e.cause) + fmt.Fprintf(s, "[%s]%s", e.RFCCode(), e.GetMsg()) + return + } + fmt.Fprint(s, e.Error()) + return + } + fallthrough + case 's': + fmt.Fprint(s, e.Error()) + case 'q': + fmt.Fprintf(s, "%q", e.Error()) + } +} + +func (e *Error) GetMsg() string { + if len(e.args) > 0 { + return fmt.Sprintf(e.message, e.args...) + } + return e.message +} + +func freezeHackedStringArgs(args []any) []any { + // This helper intentionally mutates the input slice in place. + // It is only used inside error-construction paths where args are internal and should + // not be reused by callers after passing into Gen*/FastGen* APIs. + for i := range args { + hackedArg, ok := args[i].(HackedStr) + if !ok { + continue + } + // TiDB may pass unsafe zero-copy strings (for example from chunk buffers) as error args. + // Error message rendering is deferred until GetMsg/Error, so without freezing here we may + // observe later writes and print a different value than the one used when creating the error. + // Freezing in this central path keeps the copy cost only on error construction. + args[i] = hackedArg.FreezeStr() + } + return args +} + +func (e *Error) GetSelfMsg() string { + return e.GetMsg() +} + +func (e *Error) fillLineAndFile(skip int) { + // skip this + _, file, line, ok := runtime.Caller(skip + 1) + if !ok { + e.file = "" + e.line = -1 + return + } + e.file = file + e.line = line +} + +// GenWithStack generates a new *Error with the same class and code, and a new formatted message. +func (e *Error) GenWithStack(format string, args ...any) error { + // TODO: RedactErrorArg + err := *e + err.message = format + err.args = freezeHackedStringArgs(args) + err.fillLineAndFile(1) + return AddStack(&err) +} + +// GenWithStackByArgs generates a new *Error with the same class and code, and new arguments. +func (e *Error) GenWithStackByArgs(args ...any) error { + RedactErrorArg(args, e.redactArgsPos) + err := *e + err.args = freezeHackedStringArgs(args) + err.fillLineAndFile(1) + return AddStack(&err) +} + +// FastGen generates a new *Error with the same class and code, and a new formatted message. +// This will not call runtime.Caller to get file and line. +func (e *Error) FastGen(format string, args ...any) error { + // TODO: RedactErrorArg + err := *e + err.message = format + err.args = freezeHackedStringArgs(args) + return SuspendStack(&err) +} + +// FastGen generates a new *Error with the same class and code, and a new arguments. +// This will not call runtime.Caller to get file and line. +func (e *Error) FastGenByArgs(args ...any) error { + RedactErrorArg(args, e.redactArgsPos) + err := *e + err.args = freezeHackedStringArgs(args) + return SuspendStack(&err) +} + +// Equal checks if err is equal to e. +func (e *Error) Equal(err error) bool { + originErr := Cause(err) + if originErr == nil { + return false + } + if error(e) == originErr { + return true + } + inErr, ok := originErr.(*Error) + if !ok { + return false + } + idEquals := e.ID() == inErr.ID() + return idEquals +} + +// NotEqual checks if err is not equal to e. +func (e *Error) NotEqual(err error) bool { + return !e.Equal(err) +} + +// RedactErrorArg redacts the args by position if RedactLogEnabled is enabled. +func RedactErrorArg(args []any, position []int) { + switch RedactLogEnabled.Load() { + case RedactLogEnable: + for _, pos := range position { + if len(args) > pos { + args[pos] = "?" + } + } + case RedactLogMarker: + for _, pos := range position { + if len(args) > pos { + args[pos] = &redactFormatter{args[pos]} + } + } + } +} + +// ErrorEqual returns a boolean indicating whether err1 is equal to err2. +func ErrorEqual(err1, err2 error) bool { + e1 := Cause(err1) + e2 := Cause(err2) + + if e1 == e2 { + return true + } + + if e1 == nil || e2 == nil { + return e1 == e2 + } + + te1, ok1 := e1.(*Error) + te2, ok2 := e2.(*Error) + if ok1 && ok2 { + return te1.Equal(te2) + } + + return e1.Error() == e2.Error() +} + +// ErrorNotEqual returns a boolean indicating whether err1 isn't equal to err2. +func ErrorNotEqual(err1, err2 error) bool { + return !ErrorEqual(err1, err2) +} + +type jsonError struct { + // Deprecated field, please use `RFCCode` instead. + Class int `json:"class"` + Code int `json:"code"` + Msg string `json:"message"` + RFCCode string `json:"rfccode"` +} + +func (e *Error) Wrap(err error) *Error { + if err != nil { + newErr := *e + newErr.cause = err + return &newErr + } + return nil +} + +// Unwrap returns cause of the error. +// It allows Error to work with errors.Is() and errors.As() from the Go +// standard package. +func (e *Error) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +// Is checks if e has the same error ID with other. +// It allows Error to work with errors.Is() from the Go standard package. +func (e *Error) Is(other error) bool { + err, ok := other.(*Error) + if !ok { + return false + } + return (e == nil && err == nil) || (e != nil && err != nil && e.ID() == err.ID()) +} + +func (e *Error) Cause() error { + root := Unwrap(e.cause) + if root == nil { + return e.cause + } + return root +} + +func (e *Error) FastGenWithCause(args ...any) error { + err := *e + if e.cause != nil { + err.message = e.cause.Error() + } + err.args = freezeHackedStringArgs(args) + return SuspendStack(&err) +} + +func (e *Error) GenWithStackByCause(args ...any) error { + err := *e + if e.cause != nil { + err.message = e.cause.Error() + } + err.args = freezeHackedStringArgs(args) + err.fillLineAndFile(1) + return AddStack(&err) +} + +type NormalizeOption func(*Error) + +func RedactArgs(pos []int) NormalizeOption { + return func(e *Error) { + e.redactArgsPos = pos + } +} + +// RFCCodeText returns a NormalizeOption to set RFC error code. +func RFCCodeText(codeText string) NormalizeOption { + return func(e *Error) { + e.codeText = ErrCodeText(codeText) + } +} + +// MySQLErrorCode returns a NormalizeOption to set error code. +func MySQLErrorCode(code int) NormalizeOption { + return func(e *Error) { + e.code = ErrCode(code) + } +} + +// Normalize creates a new Error object. +func Normalize(message string, opts ...NormalizeOption) *Error { + e := &Error{ + message: message, + } + for _, opt := range opts { + opt(e) + } + return e +} + +type redactFormatter struct { + arg any +} + +func (e *redactFormatter) Format(f fmt.State, verb rune) { + origin := fmt.Sprintf(fmt.FormatString(f, verb), e.arg) + fmt.Fprintf(f, "‹") + for _, c := range origin { + if c == '‹' || c == '›' { + fmt.Fprintf(f, "%c", c) + fmt.Fprintf(f, "%c", c) + } else { + fmt.Fprintf(f, "%c", c) + } + } + fmt.Fprintf(f, "›") +} diff --git a/normalize_test.go b/normalize_test.go new file mode 100644 index 00000000..89557c1c --- /dev/null +++ b/normalize_test.go @@ -0,0 +1,200 @@ +package errors + +import ( + stderrors "errors" + "fmt" + "regexp" + "strings" + "testing" + "unsafe" +) + +type hackedStringArg struct { + raw []byte +} + +func (h hackedStringArg) FreezeStr() string { + return string(append([]byte(nil), h.raw...)) +} + +func errorMatches(t *testing.T, err error, re string) { + if err == nil && re != "" { + t.Errorf("nil error doesn't match %s", re) + return + } + match, reErr := regexp.MatchString(re, err.Error()) + if reErr != nil { + t.Errorf("invalid regexp %s (%s)", re, reErr.Error()) + return + } + if !match { + t.Errorf("error %s doesn't match %s", err.Error(), re) + return + } + t.Logf("passed: %s ~= %s", err.Error(), re) +} + +func TestCauseInErrorMessage(t *testing.T) { + errTest := Normalize("this error just for testing", RFCCodeText("Internal:Test")) + + wrapped := errTest.Wrap(New("everything is alright :)")) + errorMatches(t, wrapped, `\[Internal:Test\]this error just for testing: everything is alright :\)`) + + notWrapped := errTest.GenWithStack("everything is alright") + errorMatches(t, notWrapped, `^\[Internal:Test\]everything is alright$`) +} + +func TestWrappedNamedErrorGenWithStackByArgsFormatsCauseStack(t *testing.T) { + errTest := Normalize("named error: %s", RFCCodeText("Internal:Test")) + + err := errTest.Wrap(New("cause error")).GenWithStackByArgs("wrapped") + + if _, ok := err.(fmt.Formatter); !ok { + t.Fatalf("zap requires fmt.Formatter to emit errorVerbose for stackful errors, got %T", err) + } + + formatted := fmt.Sprintf("%+v", err) + if !strings.Contains(formatted, "github.com/pingcap/errors.TestWrappedNamedErrorGenWithStackByArgsFormatsCauseStack") { + t.Fatalf("formatted error does not contain the wrapped cause stack:\n%s", formatted) + } + if !strings.Contains(formatted, "[Internal:Test]named error: wrapped") { + t.Fatalf("formatted error does not contain named error context:\n%s", formatted) + } +} + +func TestWrappedNamedErrorFormatsStacklessCause(t *testing.T) { + errTest := Normalize("named error: %s", RFCCodeText("Internal:Test")) + + err := errTest.Wrap(stderrors.New("plain cause")).GenWithStackByArgs("wrapped") + + formatted := fmt.Sprintf("%+v", err) + wantPrefix := "plain cause\n[Internal:Test]named error: wrapped\n" + if !strings.HasPrefix(formatted, wantPrefix) { + t.Fatalf("unexpected formatted error prefix:\ngot: %q\nwant prefix: %q", formatted, wantPrefix) + } + if !strings.Contains(formatted, "github.com/pingcap/errors.TestWrappedNamedErrorFormatsStacklessCause") { + t.Fatalf("formatted error does not contain the generated stack:\n%s", formatted) + } +} + +func TestRedactFormatter(t *testing.T) { + rv := 34.03498 + v := &redactFormatter{rv} + for _, f := range []string{"%d", "%.2d"} { + a := fmt.Sprintf(f, v) + b := fmt.Sprintf("‹"+f+"›", rv) + if a != b { + t.Errorf("%s != %s", a, b) + } + } + + v = &redactFormatter{"‹"} + if a := fmt.Sprintf("%s", v); a != "‹‹‹›" { + t.Errorf("%s != <<<>", a) + } +} + +func TestGenWithStackByArgsNoCloneByDefault(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := *(*string)(unsafe.Pointer(&origin)) + err := errTest.GenWithStackByArgs(arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '1 1:1:1.0000027'" + if got != want { + t.Fatalf("message should track source bytes by default, got %q, want %q", got, want) + } +} + +func TestGenWithStackByArgsFreezeHackedStringArg(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := hackedStringArg{raw: origin} + err := errTest.GenWithStackByArgs(arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '120120519090607'" + if got != want { + t.Fatalf("message changed after source bytes mutated, got %q, want %q", got, want) + } +} + +func TestFastGenByArgsFreezeHackedStringArg(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := hackedStringArg{raw: origin} + err := errTest.FastGenByArgs(arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '120120519090607'" + if got != want { + t.Fatalf("message changed after source bytes mutated, got %q, want %q", got, want) + } +} + +func TestGenWithStackFreezeHackedStringArg(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := hackedStringArg{raw: origin} + err := errTest.GenWithStack("Incorrect time value: '%s'", arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '120120519090607'" + if got != want { + t.Fatalf("message changed after source bytes mutated, got %q, want %q", got, want) + } +} + +func TestFastGenFreezeHackedStringArg(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := hackedStringArg{raw: origin} + err := errTest.FastGen("Incorrect time value: '%s'", arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '120120519090607'" + if got != want { + t.Fatalf("message changed after source bytes mutated, got %q, want %q", got, want) + } +} + +func TestGenWithStackByCauseFreezeHackedStringArg(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := hackedStringArg{raw: origin} + err := errTest.GenWithStackByCause(arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '120120519090607'" + if got != want { + t.Fatalf("message changed after source bytes mutated, got %q, want %q", got, want) + } +} + +func TestFastGenWithCauseFreezeHackedStringArg(t *testing.T) { + errTest := Normalize("Incorrect time value: '%s'", RFCCodeText("Internal:Test")) + + origin := []byte("120120519090607") + arg := hackedStringArg{raw: origin} + err := errTest.FastGenWithCause(arg) + + copy(origin, "1 1:1:1.0000027") + got := err.(*withStack).error.(*Error).GetMsg() + want := "Incorrect time value: '120120519090607'" + if got != want { + t.Fatalf("message changed after source bytes mutated, got %q, want %q", got, want) + } +} diff --git a/stack.go b/stack.go index 2874a048..2263ba37 100644 --- a/stack.go +++ b/stack.go @@ -1,13 +1,40 @@ package errors import ( + "bytes" "fmt" "io" "path" "runtime" + "strconv" "strings" ) +// StackTracer retrieves the StackTrace +// Generally you would want to use the GetStackTracer function to do that. +type StackTracer interface { + StackTrace() StackTrace + // Empty returns true if the stack trace is empty, StackTrace might clone the + // stack trace, add this method to avoid unnecessary clone. + Empty() bool +} + +// GetStackTracer will return the first StackTracer in the causer chain. +// This function is used by AddStack to avoid creating redundant stack traces. +// +// You can also use the StackTracer interface on the returned error to get the stack trace. +func GetStackTracer(origErr error) StackTracer { + var stacked StackTracer + WalkDeep(origErr, func(err error) bool { + if stackTracer, ok := err.(StackTracer); ok { + stacked = stackTracer + return true + } + return false + }) + return stacked +} + // Frame represents a program counter inside a stack frame. type Frame uintptr @@ -39,17 +66,22 @@ func (f Frame) line() int { // Format formats the frame according to the fmt.Formatter interface. // -// %s source file -// %d source line -// %n function name -// %v equivalent to %s:%d +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d // // Format accepts flags that alter the printing of some verbs, as follows: // -// %+s function name and path of source file relative to the compile time -// GOPATH separated by \n\t (\n\t) -// %+v equivalent to %+s:%d +// %+s function name and path of source file relative to the compile time +// GOPATH separated by \n\t (\n\t) +// %+v equivalent to %+s:%d func (f Frame) Format(s fmt.State, verb rune) { + f.format(s, s, verb) +} + +// format allows stack trace printing calls to be made with a bytes.Buffer. +func (f Frame) format(w io.Writer, s fmt.State, verb rune) { switch verb { case 's': switch { @@ -57,23 +89,25 @@ func (f Frame) Format(s fmt.State, verb rune) { pc := f.pc() fn := runtime.FuncForPC(pc) if fn == nil { - io.WriteString(s, "unknown") + io.WriteString(w, "unknown") } else { file, _ := fn.FileLine(pc) - fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + io.WriteString(w, fn.Name()) + io.WriteString(w, "\n\t") + io.WriteString(w, file) } default: - io.WriteString(s, path.Base(f.file())) + io.WriteString(w, path.Base(f.file())) } case 'd': - fmt.Fprintf(s, "%d", f.line()) + io.WriteString(w, strconv.Itoa(f.line())) case 'n': name := runtime.FuncForPC(f.pc()).Name() - io.WriteString(s, funcname(name)) + io.WriteString(w, funcname(name)) case 'v': - f.Format(s, 's') - io.WriteString(s, ":") - f.Format(s, 'd') + f.format(w, s, 's') + io.WriteString(w, ":") + f.format(w, s, 'd') } } @@ -82,30 +116,57 @@ type StackTrace []Frame // Format formats the stack of Frames according to the fmt.Formatter interface. // -// %s lists source files for each Frame in the stack -// %v lists the source file and line number for each Frame in the stack +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack // // Format accepts flags that alter the printing of some verbs, as follows: // -// %+v Prints filename, function, and line number for each Frame in the stack. +// %+v Prints filename, function, and line number for each Frame in the stack. func (st StackTrace) Format(s fmt.State, verb rune) { + var b bytes.Buffer switch verb { case 'v': switch { case s.Flag('+'): - for _, f := range st { - fmt.Fprintf(s, "\n%+v", f) + b.Grow(len(st) * stackMinLen) + for _, fr := range st { + b.WriteByte('\n') + fr.format(&b, s, verb) } case s.Flag('#'): - fmt.Fprintf(s, "%#v", []Frame(st)) + fmt.Fprintf(&b, "%#v", []Frame(st)) default: - fmt.Fprintf(s, "%v", []Frame(st)) + st.formatSlice(&b, s, verb) } case 's': - fmt.Fprintf(s, "%s", []Frame(st)) + st.formatSlice(&b, s, verb) + } + io.Copy(s, &b) +} + +// formatSlice will format this StackTrace into the given buffer as a slice of +// Frame, only valid when called with '%s' or '%v'. +func (st StackTrace) formatSlice(b *bytes.Buffer, s fmt.State, verb rune) { + b.WriteByte('[') + if len(st) == 0 { + b.WriteByte(']') + return } + + b.Grow(len(st) * (stackMinLen / 4)) + st[0].format(b, s, verb) + for _, fr := range st[1:] { + b.WriteByte(' ') + fr.format(b, s, verb) + } + b.WriteByte(']') } +// stackMinLen is a best-guess at the minimum length of a stack trace. It +// doesn't need to be exact, just give a good enough head start for the buffer +// to avoid the expensive early growth. +const stackMinLen = 96 + // stack represents a stack of program counters. type stack []uintptr @@ -114,26 +175,38 @@ func (s *stack) Format(st fmt.State, verb rune) { case 'v': switch { case st.Flag('+'): + var b bytes.Buffer + b.Grow(len(*s) * stackMinLen) for _, pc := range *s { f := Frame(pc) - fmt.Fprintf(st, "\n%+v", f) + b.WriteByte('\n') + f.format(&b, st, 'v') } + io.Copy(st, &b) } } } func (s *stack) StackTrace() StackTrace { f := make([]Frame, len(*s)) - for i := 0; i < len(f); i++ { + for i := range f { f[i] = Frame((*s)[i]) } return f } +func (s *stack) Empty() bool { + return len(*s) == 0 +} + func callers() *stack { + return callersSkip(4) +} + +func callersSkip(skip int) *stack { const depth = 32 var pcs [depth]uintptr - n := runtime.Callers(3, pcs[:]) + n := runtime.Callers(skip, pcs[:]) var st stack = pcs[0:n] return &st } @@ -145,3 +218,16 @@ func funcname(name string) string { i = strings.Index(name, ".") return name[i+1:] } + +// NewStack is for library implementers that want to generate a stack trace. +// Normally you should insted use AddStack to get an error with a stack trace. +// +// The result of this function can be turned into a stack trace by calling .StackTrace() +// +// This function takes an argument for the number of stack frames to skip. +// This avoids putting stack generation function calls like this one in the stack trace. +// A value of 0 will give you the line that called NewStack(0) +// A library author wrapping this in their own function will want to use a value of at least 1. +func NewStack(skip int) StackTracer { + return callersSkip(skip + 3) +} diff --git a/stack_test.go b/stack_test.go index 85fc4195..d6d855ee 100644 --- a/stack_test.go +++ b/stack_test.go @@ -2,7 +2,9 @@ package errors import ( "fmt" + "io" "runtime" + "strings" "testing" ) @@ -14,23 +16,23 @@ func TestFrameLine(t *testing.T) { want int }{{ Frame(initpc), - 9, + 11, }, { func() Frame { var pc, _, _, _ = runtime.Caller(0) return Frame(pc) }(), - 20, - }, { - func() Frame { - var pc, _, _, _ = runtime.Caller(1) - return Frame(pc) - }(), - 28, - }, { - Frame(0), // invalid PC - 0, - }} + 22, + }, /* { // TODO stdlib `runtime` Behavior changed between 1.13 and 1.14 + func() Frame { + var pc, _, _, _ = runtime.Caller(1) + return Frame(pc) + }(), + 24, + }, */{ + Frame(0), // invalid PC + 0, + }} for _, tt := range tests { got := tt.Frame.line() @@ -65,8 +67,8 @@ func TestFrameFormat(t *testing.T) { }, { Frame(initpc), "%+s", - "github.com/pkg/errors.init\n" + - "\t.+/github.com/pkg/errors/stack_test.go", + "github.com/pingcap/errors.init\n" + + "\t.+/pingcap/errors/stack_test.go", }, { Frame(0), "%s", @@ -78,7 +80,7 @@ func TestFrameFormat(t *testing.T) { }, { Frame(initpc), "%d", - "9", + "11", }, { Frame(0), "%d", @@ -108,12 +110,12 @@ func TestFrameFormat(t *testing.T) { }, { Frame(initpc), "%v", - "stack_test.go:9", + "stack_test.go:11", }, { Frame(initpc), "%+v", - "github.com/pkg/errors.init\n" + - "\t.+/github.com/pkg/errors/stack_test.go:9", + "github.com/pingcap/errors.init\n" + + "\t.+/pingcap/errors/stack_test.go:11", }, { Frame(0), "%v", @@ -131,7 +133,7 @@ func TestFuncname(t *testing.T) { }{ {"", ""}, {"runtime.main", "main"}, - {"github.com/pkg/errors.funcname", "funcname"}, + {"github.com/pingcap/errors.funcname", "funcname"}, {"funcname", "funcname"}, {"io.copyBuffer", "copyBuffer"}, {"main.(*R).Write", "(*R).Write"}, @@ -152,25 +154,25 @@ func TestStackTrace(t *testing.T) { want []string }{{ New("ooh"), []string{ - "github.com/pkg/errors.TestStackTrace\n" + - "\t.+/github.com/pkg/errors/stack_test.go:154", + "github.com/pingcap/errors.TestStackTrace\n" + + "\t.+/pingcap/errors/stack_test.go", }, }, { - Wrap(New("ooh"), "ahh"), []string{ - "github.com/pkg/errors.TestStackTrace\n" + - "\t.+/github.com/pkg/errors/stack_test.go:159", // this is the stack of Wrap, not New + Annotate(New("ooh"), "ahh"), []string{ + "github.com/pingcap/errors.TestStackTrace\n" + + "\t.+/pingcap/errors/stack_test.go", // this is the stack of Wrap, not New }, }, { - Cause(Wrap(New("ooh"), "ahh")), []string{ - "github.com/pkg/errors.TestStackTrace\n" + - "\t.+/github.com/pkg/errors/stack_test.go:164", // this is the stack of New + Cause(Annotate(New("ooh"), "ahh")), []string{ + "github.com/pingcap/errors.TestStackTrace\n" + + "\t.+/pingcap/errors/stack_test.go", // this is the stack of New }, }, { func() error { return New("ooh") }(), []string{ - `github.com/pkg/errors.(func·009|TestStackTrace.func1)` + - "\n\t.+/github.com/pkg/errors/stack_test.go:169", // this is the stack of New - "github.com/pkg/errors.TestStackTrace\n" + - "\t.+/github.com/pkg/errors/stack_test.go:169", // this is the stack of New's caller + `github.com/pingcap/errors.(func·009|TestStackTrace.func1)` + + "\n\t.+/pingcap/errors/stack_test.go", // this is the stack of New + "github.com/pingcap/errors.TestStackTrace\n" + + "\t.+/pingcap/errors/stack_test.go", // this is the stack of New's caller }, }, { Cause(func() error { @@ -178,35 +180,40 @@ func TestStackTrace(t *testing.T) { return Errorf("hello %s", fmt.Sprintf("world")) }() }()), []string{ - `github.com/pkg/errors.(func·010|TestStackTrace.func2.1)` + - "\n\t.+/github.com/pkg/errors/stack_test.go:178", // this is the stack of Errorf - `github.com/pkg/errors.(func·011|TestStackTrace.func2)` + - "\n\t.+/github.com/pkg/errors/stack_test.go:179", // this is the stack of Errorf's caller - "github.com/pkg/errors.TestStackTrace\n" + - "\t.+/github.com/pkg/errors/stack_test.go:180", // this is the stack of Errorf's caller's caller + // go 1.23 when Debug its suffix is "TestStackTrace.func2.1", it's + // "TestStackTrace.TestStackTrace.func2.func3" when Run + `github.com/pingcap/errors.(func·010|TestStackTrace.func2.1|TestStackTrace.TestStackTrace.func2.func3)` + + "\n\t.+/pingcap/errors/stack_test.go", // this is the stack of Errorf + `github.com/pingcap/errors.(func·011|TestStackTrace.func2)` + + "\n\t.+/pingcap/errors/stack_test.go", // this is the stack of Errorf's caller + "github.com/pingcap/errors.TestStackTrace\n" + + "\t.+/pingcap/errors/stack_test.go", // this is the stack of Errorf's caller's caller }, }} for i, tt := range tests { - x, ok := tt.err.(interface { + ste, ok := tt.err.(interface { StackTrace() StackTrace }) if !ok { - t.Errorf("expected %#v to implement StackTrace() StackTrace", tt.err) - continue + ste = tt.err.(interface { + Cause() error + }).Cause().(interface { + StackTrace() StackTrace + }) } - st := x.StackTrace() + st := ste.StackTrace() for j, want := range tt.want { testFormatRegexp(t, i, st[j], "%+v", want) } } } +// This comment helps to maintain original line numbers +// Perhaps this test is too fragile :) func stackTrace() StackTrace { - const depth = 8 - var pcs [depth]uintptr - n := runtime.Callers(1, pcs[:]) - var st stack = pcs[0:n] - return st.StackTrace() + return NewStack(0).StackTrace() + // This comment helps to maintain original line numbers + // Perhaps this test is too fragile :) } func TestStackTraceFormat(t *testing.T) { @@ -253,22 +260,82 @@ func TestStackTraceFormat(t *testing.T) { }, { stackTrace()[:2], "%v", - `\[stack_test.go:207 stack_test.go:254\]`, + `[stack_test.go:207 stack_test.go:254]`, }, { stackTrace()[:2], "%+v", "\n" + - "github.com/pkg/errors.stackTrace\n" + - "\t.+/github.com/pkg/errors/stack_test.go:207\n" + - "github.com/pkg/errors.TestStackTraceFormat\n" + - "\t.+/github.com/pkg/errors/stack_test.go:258", + "github.com/pingcap/errors.stackTrace\n" + + "\t.+/pingcap/errors/stack_test.go:\\d+\n" + + "github.com/pingcap/errors.TestStackTraceFormat\n" + + "\t.+/pingcap/errors/stack_test.go:\\d+", }, { stackTrace()[:2], "%#v", - `\[\]errors.Frame{stack_test.go:207, stack_test.go:266}`, + `\[\]errors.Frame{stack_test.go:\d+, stack_test.go:\d+}`, }} for i, tt := range tests { testFormatRegexp(t, i, tt.StackTrace, tt.format, tt.want) } } + +func TestNewStack(t *testing.T) { + got := NewStack(1).StackTrace() + want := NewStack(1).StackTrace() + if got[0] != want[0] { + t.Errorf("NewStack(remove NewStack): want: %v, got: %v", want, got) + } + gotFirst := fmt.Sprintf("%+v", got[0])[0:15] + if gotFirst != "testing.tRunner" { + t.Errorf("NewStack(): want: %v, got: %+v", "testing.tRunner", gotFirst) + } +} + +func TestNewNoStackError(t *testing.T) { + err := NewNoStackError("test error") + err = Trace(err) + err = Trace(err) + result := fmt.Sprintf("%+v", err) + if !strings.Contains(result, "test error") || + !strings.Contains(result, "pingcap/errors.TestNewNoStackError") { + t.Errorf("NewNoStackError(): want %s, got %v", "test error", result) + } +} + +func TestNewNoStackErrorf(t *testing.T) { + err := NewNoStackErrorf("test error %s", "yes") + err = Trace(err) + err = Trace(err) + result := fmt.Sprintf("%+v", err) + if !strings.Contains(result, "test error yes") || + !strings.Contains(result, "pingcap/errors.TestNewNoStackErrorf") { + t.Errorf("NewNoStackError(): want %s, got %v", "test error", result) + } +} + +func TestSuspendError(t *testing.T) { + err := io.EOF + err = SuspendStack(err) + err = Trace(err) + result := fmt.Sprintf("%+v", err) + if !strings.Contains(result, "EOF") || + !strings.Contains(result, "pingcap/errors.TestSuspendError") { + t.Errorf("NewNoStackError(): want %s, got %v", "EOF", result) + } + if io.EOF != Cause(err) { + t.Errorf("SuspendStackError can not got back origion error.") + } +} + +func TestSuspendTracedWithMessageError(t *testing.T) { + tracedErr := Trace(io.EOF) + tracedErr = WithStack(tracedErr) + tracedErr = WithMessage(tracedErr, "1") + tracedErr = SuspendStack(tracedErr) + tracedErr = Trace(tracedErr) + result := fmt.Sprintf("%+v", tracedErr) + if result != "EOF\n1" { + t.Errorf("NewNoStackError(): want %s, got %v", "EOF\n1", result) + } +} diff --git a/terror_test/terror_test.go b/terror_test/terror_test.go new file mode 100644 index 00000000..86272fff --- /dev/null +++ b/terror_test/terror_test.go @@ -0,0 +1,171 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package terror_test + +import ( + "encoding/json" + "os" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/pingcap/errors" +) + +const ( + CodeMissConnectionID errors.ErrCode = 1 + CodeResultUndetermined errors.ErrCode = 2 + CodeExecResultIsEmpty errors.ErrCode = 3 +) + +type TErrorTestSuite struct { + suite.Suite +} + +func (s *TErrorTestSuite) TestErrCode() { + s.Equal(CodeMissConnectionID, errors.ErrCode(1)) + s.Equal(CodeResultUndetermined, errors.ErrCode(2)) + s.Equal(CodeExecResultIsEmpty, errors.ErrCode(3)) +} + +var predefinedErr = errors.Normalize("predefiend error", errors.MySQLErrorCode(123)) +var predefinedTextualErr = errors.Normalize("executor is taking vacation at %s", errors.RFCCodeText("executor:ExecutorAbsent")) + +func example() error { + err := call() + return errors.Trace(err) +} + +func call() error { + return predefinedErr.GenWithStack("error message:%s", "abc") +} + +func (s *TErrorTestSuite) TestJson() { + tmpErr := errors.Normalize("this is a test error", errors.RFCCodeText("ddl:-1"), errors.MySQLErrorCode(-1)) + buf, err := json.Marshal(tmpErr) + s.Nil(err) + var curTErr errors.Error + err = json.Unmarshal(buf, &curTErr) + s.Nil(err) + isEqual := tmpErr.Equal(&curTErr) + s.Equal(curTErr.Error(), tmpErr.Error()) + s.True(isEqual) +} + +func (s *TErrorTestSuite) TestTraceAndLocation() { + err := example() + stack := errors.ErrorStack(err) + lines := strings.Split(stack, "\n") + goroot := strings.ReplaceAll(runtime.GOROOT(), string(os.PathSeparator), "/") + var sysStack = 0 + for _, line := range lines { + if strings.Contains(line, goroot) { + sysStack++ + } + } + s.Equalf(11, len(lines)-(2*sysStack), "stack = \n%s", stack) + var containTerr bool + for _, v := range lines { + if strings.Contains(v, "terror_test.go") { + containTerr = true + break + } + } + s.True(containTerr) +} + +func (s *TErrorTestSuite) TestErrorEqual() { + e1 := errors.New("test error") + s.NotNil(e1) + + e2 := errors.Trace(e1) + s.NotNil(e2) + + e3 := errors.Trace(e2) + s.NotNil(e3) + + s.Equal(e1, errors.Cause(e2)) + s.Equal(e1, errors.Cause(e3)) + s.Equal(errors.Cause(e3), errors.Cause(e2)) + + e4 := errors.New("test error") + s.NotEqual(e1, errors.Cause(e4)) + + e5 := errors.Errorf("test error") + s.NotEqual(e1, errors.Cause(e5)) + + s.True(errors.ErrorEqual(e1, e2)) + s.True(errors.ErrorEqual(e1, e3)) + s.True(errors.ErrorEqual(e1, e4)) + s.True(errors.ErrorEqual(e1, e5)) + + var e6 error + + s.True(errors.ErrorEqual(nil, nil)) + s.True(errors.ErrorNotEqual(e1, e6)) +} + +func (s *TErrorTestSuite) TestNewError() { + today := time.Now().Weekday().String() + err := predefinedTextualErr.GenWithStackByArgs(today) + s.NotNil(err) + s.Equal("[executor:ExecutorAbsent]executor is taking vacation at "+today, err.Error()) +} + +func (s *TErrorTestSuite) TestRFCCode() { + c1err1 := errors.Normalize("nothing", errors.RFCCodeText("TestErr1:Err1")) + c2err2 := errors.Normalize("nothing", errors.RFCCodeText("TestErr2:Err2")) + s.Equal(errors.RFCErrorCode("TestErr1:Err1"), c1err1.RFCCode()) + s.Equal(errors.RFCErrorCode("TestErr2:Err2"), c2err2.RFCCode()) + + berr := errors.Normalize("nothing", errors.RFCCodeText("Blank:B1")) + s.Equal(errors.RFCErrorCode("Blank:B1"), berr.RFCCode()) +} + +func (s *TErrorTestSuite) TestLineAndFile() { + err := predefinedTextualErr.GenWithStackByArgs("everyday") + _, f, l, _ := runtime.Caller(0) + terr, ok := errors.Cause(err).(*errors.Error) + s.True(ok) + + file, line := terr.Location() + s.Equal(f, file) + s.Equal(l-1, line) + + err2 := predefinedTextualErr.GenWithStackByArgs("everyday and everywhere") + _, f2, l2, _ := runtime.Caller(0) + terr2, ok2 := errors.Cause(err2).(*errors.Error) + s.True(ok2) + file2, line2 := terr2.Location() + s.Equal(f2, file2) + s.Equal(l2-1, line2) +} + +func (s *TErrorTestSuite) TestWarpAndField() { + cause := errors.New("load from etcd meet error") + s.NotNil(cause) + + err := errors.Normalize("fail to get leader", errors.RFCCodeText("member:ErrGetLeader")) + errWithCause := errors.Annotate(err, cause.Error()) + s.NotNil(errWithCause) + + s.Equal("load from etcd meet error: [member:ErrGetLeader]fail to get leader", errWithCause.Error()) +} + +func TestExampleTestSuite(t *testing.T) { + suite.Run(t, new(TErrorTestSuite)) +}