From de4748c47c67392a57f250714509f590f68ad395 Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Wed, 1 Feb 2023 18:43:23 +0000 Subject: [PATCH 01/17] [release-branch.go1.20] go1.20 Change-Id: I156873d216ccb7d91e716b4348069df246b527b3 Reviewed-on: https://go-review.googlesource.com/c/go/+/464496 Run-TryBot: Gopher Robot Auto-Submit: Gopher Robot Reviewed-by: Matthew Dempsky TryBot-Result: Gopher Robot Reviewed-by: Michael Knyszek --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 3faae45cd8798..83534e24796a8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -go1.20rc3 \ No newline at end of file +go1.20 \ No newline at end of file From 7302f83d8733203aa23f056690d20a4adb949424 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sat, 4 Feb 2023 12:00:20 +0700 Subject: [PATCH 02/17] [release-branch.go1.20] cmd/compile: remove constant arithmetic overflows during typecheck Since go1.19, these errors are already reported by types2 for any user's Go code. Compiler generated code, which looks like constant expression should be evaluated as non-constant semantic, which allows overflows. Fixes #58319 Change-Id: I6f0049a69bdb0a8d0d7a0db49c7badaa92598ea2 Reviewed-on: https://go-review.googlesource.com/c/go/+/466676 Reviewed-by: Keith Randall TryBot-Result: Gopher Robot Run-TryBot: Cuong Manh Le Reviewed-by: Keith Randall Reviewed-by: Tobias Klauser --- src/cmd/compile/internal/typecheck/const.go | 36 ++------------------- test/fixedbugs/issue58293.go | 13 ++++++++ 2 files changed, 15 insertions(+), 34 deletions(-) create mode 100644 test/fixedbugs/issue58293.go diff --git a/src/cmd/compile/internal/typecheck/const.go b/src/cmd/compile/internal/typecheck/const.go index edc399ffd74c7..6855f05b7b346 100644 --- a/src/cmd/compile/internal/typecheck/const.go +++ b/src/cmd/compile/internal/typecheck/const.go @@ -34,10 +34,7 @@ func roundFloat(v constant.Value, sz int64) constant.Value { // truncate float literal fv to 32-bit or 64-bit precision // according to type; return truncated value. func truncfltlit(v constant.Value, t *types.Type) constant.Value { - if t.IsUntyped() || overflow(v, t) { - // If there was overflow, simply continuing would set the - // value to Inf which in turn would lead to spurious follow-on - // errors. Avoid this by returning the existing value. + if t.IsUntyped() { return v } @@ -48,10 +45,7 @@ func truncfltlit(v constant.Value, t *types.Type) constant.Value { // precision, according to type; return truncated value. In case of // overflow, calls Errorf but does not truncate the input value. func trunccmplxlit(v constant.Value, t *types.Type) constant.Value { - if t.IsUntyped() || overflow(v, t) { - // If there was overflow, simply continuing would set the - // value to Inf which in turn would lead to spurious follow-on - // errors. Avoid this by returning the existing value. + if t.IsUntyped() { return v } @@ -251,7 +245,6 @@ func convertVal(v constant.Value, t *types.Type, explicit bool) constant.Value { switch { case t.IsInteger(): v = toint(v) - overflow(v, t) return v case t.IsFloat(): v = toflt(v) @@ -273,9 +266,6 @@ func tocplx(v constant.Value) constant.Value { func toflt(v constant.Value) constant.Value { if v.Kind() == constant.Complex { - if constant.Sign(constant.Imag(v)) != 0 { - base.Errorf("constant %v truncated to real", v) - } v = constant.Real(v) } @@ -284,9 +274,6 @@ func toflt(v constant.Value) constant.Value { func toint(v constant.Value) constant.Value { if v.Kind() == constant.Complex { - if constant.Sign(constant.Imag(v)) != 0 { - base.Errorf("constant %v truncated to integer", v) - } v = constant.Real(v) } @@ -321,25 +308,6 @@ func toint(v constant.Value) constant.Value { return constant.MakeInt64(1) } -// overflow reports whether constant value v is too large -// to represent with type t, and emits an error message if so. -func overflow(v constant.Value, t *types.Type) bool { - // v has already been converted - // to appropriate form for t. - if t.IsUntyped() { - return false - } - if v.Kind() == constant.Int && constant.BitLen(v) > ir.ConstPrec { - base.Errorf("integer too large") - return true - } - if ir.ConstOverflow(v, t) { - base.Errorf("constant %v overflows %v", types.FmtConst(v, false), t) - return true - } - return false -} - func tostr(v constant.Value) constant.Value { if v.Kind() == constant.Int { r := unicode.ReplacementChar diff --git a/test/fixedbugs/issue58293.go b/test/fixedbugs/issue58293.go new file mode 100644 index 0000000000000..58d550025341a --- /dev/null +++ b/test/fixedbugs/issue58293.go @@ -0,0 +1,13 @@ +// compile + +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +var bar = f(13579) + +func f(x uint16) uint16 { + return x>>8 | x<<8 +} From 487be3f90bf65c06eb5f1f30aec30cd0e5b24f92 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sun, 5 Feb 2023 14:33:32 +0700 Subject: [PATCH 03/17] [release-branch.go1.20] cmd/compile: fix inline static init arguments substitued tree Blank node must be ignored when building arguments substitued tree. Otherwise, it could be used to replace other blank node in left hand side of an assignment, causing an invalid IR node. Consider the following code: type S1 struct { s2 S2 } type S2 struct{} func (S2) Make() S2 { return S2{} } func (S1) Make() S1 { return S1{s2: S2{}.Make()} } var _ = S1{}.Make() After staticAssignInlinedCall, the assignment becomes: var _ = S1{s2: S2{}.Make()} and the arg substitued tree is "map[*ir.Name]ir.Node{_: S1{}}". Now, when doing static assignment, if there is any assignment to blank node, for example: _ := S2{} That blank node will be replaced with "S1{}": S1{} := S2{} So constructing an invalid IR which causes the ICE. Fixes #58335 Change-Id: I21b48357f669a7e02a7eb4325246aadc31f78fb9 Reviewed-on: https://go-review.googlesource.com/c/go/+/465098 Run-TryBot: Cuong Manh Le Auto-Submit: Cuong Manh Le TryBot-Result: Gopher Robot Reviewed-by: Keith Randall Reviewed-by: Keith Randall Reviewed-by: David Chase Reviewed-on: https://go-review.googlesource.com/c/go/+/466275 Reviewed-by: Than McIntosh --- src/cmd/compile/internal/staticinit/sched.go | 3 +++ test/fixedbugs/issue58325.go | 23 ++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 test/fixedbugs/issue58325.go diff --git a/src/cmd/compile/internal/staticinit/sched.go b/src/cmd/compile/internal/staticinit/sched.go index bd1bf4114d334..2bfb5d79b2276 100644 --- a/src/cmd/compile/internal/staticinit/sched.go +++ b/src/cmd/compile/internal/staticinit/sched.go @@ -615,6 +615,9 @@ func (s *Schedule) staticAssignInlinedCall(l *ir.Name, loff int64, call *ir.Inli // Build tree with args substituted for params and try it. args := make(map[*ir.Name]ir.Node) for i, v := range as2init.Lhs { + if ir.IsBlank(v) { + continue + } args[v.(*ir.Name)] = as2init.Rhs[i] } r, ok := subst(as2body.Rhs[0], args) diff --git a/test/fixedbugs/issue58325.go b/test/fixedbugs/issue58325.go new file mode 100644 index 0000000000000..d37089c800f71 --- /dev/null +++ b/test/fixedbugs/issue58325.go @@ -0,0 +1,23 @@ +// compile + +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +type S1 struct { + s2 S2 +} + +type S2 struct{} + +func (S2) Make() S2 { + return S2{} +} + +func (S1) Make() S1 { + return S1{s2: S2{}.Make()} +} + +var _ = S1{}.Make() From 90b06002c44e7fb8fd4b9227efd2d1423e21176b Mon Sep 17 00:00:00 2001 From: Matthew Dempsky Date: Thu, 1 Dec 2022 17:24:23 -0800 Subject: [PATCH 04/17] [release-branch.go1.20] cmd/compile/internal/noder: stop creating TUNION types In the types1 universe under the unified frontend, we never need to worry about type parameter constraints, so we only see pure interfaces. However, we might still see interfaces that contain union types, because of interfaces like "interface{ any | int }" (equivalent to just "any"). We can handle these without needing to actually represent type unions within types1 by simply mapping any union to "any". Fixes #58413. Change-Id: I5e4efcf0339edbb01f4035c54fb6fb1f9ddc0c65 Reviewed-on: https://go-review.googlesource.com/c/go/+/458619 Run-TryBot: Matthew Dempsky Reviewed-by: Keith Randall TryBot-Result: Gopher Robot Reviewed-by: Keith Randall (cherry picked from commit a7de684e1b6f460aae7d4dbf2568cb21130ec520) Reviewed-on: https://go-review.googlesource.com/c/go/+/466435 Reviewed-by: Than McIntosh Run-TryBot: David Chase --- src/cmd/compile/internal/noder/reader.go | 32 +++++++++++++++++++----- test/typeparam/issue52124.go | 4 ++- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/cmd/compile/internal/noder/reader.go b/src/cmd/compile/internal/noder/reader.go index d03da27a4603a..189531959e555 100644 --- a/src/cmd/compile/internal/noder/reader.go +++ b/src/cmd/compile/internal/noder/reader.go @@ -517,13 +517,33 @@ func (r *reader) doTyp() *types.Type { } func (r *reader) unionType() *types.Type { - terms := make([]*types.Type, r.Len()) - tildes := make([]bool, len(terms)) - for i := range terms { - tildes[i] = r.Bool() - terms[i] = r.typ() + // In the types1 universe, we only need to handle value types. + // Impure interfaces (i.e., interfaces with non-trivial type sets + // like "int | string") can only appear as type parameter bounds, + // and this is enforced by the types2 type checker. + // + // However, type unions can still appear in pure interfaces if the + // type union is equivalent to "any". E.g., typeparam/issue52124.go + // declares variables with the type "interface { any | int }". + // + // To avoid needing to represent type unions in types1 (since we + // don't have any uses for that today anyway), we simply fold them + // to "any". As a consistency check, we still read the union terms + // to make sure this substitution is safe. + + pure := false + for i, n := 0, r.Len(); i < n; i++ { + _ = r.Bool() // tilde + term := r.typ() + if term.IsEmptyInterface() { + pure = true + } + } + if !pure { + base.Fatalf("impure type set used in value type") } - return types.NewUnion(terms, tildes) + + return types.Types[types.TINTER] } func (r *reader) interfaceType() *types.Type { diff --git a/test/typeparam/issue52124.go b/test/typeparam/issue52124.go index a113fc74441f1..07cba479821ee 100644 --- a/test/typeparam/issue52124.go +++ b/test/typeparam/issue52124.go @@ -6,7 +6,9 @@ package p -type I interface{ any | int } +type Any any + +type I interface{ Any | int } var ( X I = 42 From 9987cb6cf34cc893887ad2ecc9d832ee3c69c255 Mon Sep 17 00:00:00 2001 From: qmuntal Date: Fri, 27 Jan 2023 19:12:43 +0100 Subject: [PATCH 05/17] [release-branch.go1.20] time: update windows zoneinfo_abbrs zoneinfo_abbrs hasn't been updated since go 1.14, it's time to regenerate it. Fixes #58117. Change-Id: Ic156ae607c46f1f5a9408b1fc0b56de6c14a4ed4 Reviewed-on: https://go-review.googlesource.com/c/go/+/463838 Reviewed-by: Alex Brainman Run-TryBot: Quim Muntal TryBot-Result: Gopher Robot Reviewed-by: Bryan Mills Reviewed-by: Dmitri Shuralyov (cherry picked from commit 007d8f4db1f890f0d34018bb418bdc90ad4a8c35) Reviewed-on: https://go-review.googlesource.com/c/go/+/466436 Reviewed-by: Than McIntosh Reviewed-by: Quim Muntal Run-TryBot: David Chase --- src/time/zoneinfo_abbrs_windows.go | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/time/zoneinfo_abbrs_windows.go b/src/time/zoneinfo_abbrs_windows.go index 3294d0d78689b..139bda1accea9 100644 --- a/src/time/zoneinfo_abbrs_windows.go +++ b/src/time/zoneinfo_abbrs_windows.go @@ -16,10 +16,11 @@ var abbrs = map[string]abbr{ "Egypt Standard Time": {"EET", "EET"}, // Africa/Cairo "Morocco Standard Time": {"+00", "+01"}, // Africa/Casablanca "South Africa Standard Time": {"SAST", "SAST"}, // Africa/Johannesburg + "South Sudan Standard Time": {"CAT", "CAT"}, // Africa/Juba "Sudan Standard Time": {"CAT", "CAT"}, // Africa/Khartoum "W. Central Africa Standard Time": {"WAT", "WAT"}, // Africa/Lagos "E. Africa Standard Time": {"EAT", "EAT"}, // Africa/Nairobi - "Sao Tome Standard Time": {"GMT", "WAT"}, // Africa/Sao_Tome + "Sao Tome Standard Time": {"GMT", "GMT"}, // Africa/Sao_Tome "Libya Standard Time": {"EET", "EET"}, // Africa/Tripoli "Namibia Standard Time": {"CAT", "CAT"}, // Africa/Windhoek "Aleutian Standard Time": {"HST", "HDT"}, // America/Adak @@ -33,8 +34,8 @@ var abbrs = map[string]abbr{ "Venezuela Standard Time": {"-04", "-04"}, // America/Caracas "SA Eastern Standard Time": {"-03", "-03"}, // America/Cayenne "Central Standard Time": {"CST", "CDT"}, // America/Chicago - "Mountain Standard Time (Mexico)": {"MST", "MDT"}, // America/Chihuahua - "Central Brazilian Standard Time": {"-04", "-03"}, // America/Cuiaba + "Mountain Standard Time (Mexico)": {"CST", "CST"}, // America/Chihuahua + "Central Brazilian Standard Time": {"-04", "-04"}, // America/Cuiaba "Mountain Standard Time": {"MST", "MDT"}, // America/Denver "Greenland Standard Time": {"-03", "-02"}, // America/Godthab "Turks And Caicos Standard Time": {"EST", "EDT"}, // America/Grand_Turk @@ -44,7 +45,7 @@ var abbrs = map[string]abbr{ "US Eastern Standard Time": {"EST", "EDT"}, // America/Indianapolis "SA Western Standard Time": {"-04", "-04"}, // America/La_Paz "Pacific Standard Time": {"PST", "PDT"}, // America/Los_Angeles - "Central Standard Time (Mexico)": {"CST", "CDT"}, // America/Mexico_City + "Central Standard Time (Mexico)": {"CST", "CST"}, // America/Mexico_City "Saint Pierre Standard Time": {"-03", "-02"}, // America/Miquelon "Montevideo Standard Time": {"-03", "-03"}, // America/Montevideo "Eastern Standard Time": {"EST", "EDT"}, // America/New_York @@ -53,11 +54,12 @@ var abbrs = map[string]abbr{ "Magallanes Standard Time": {"-03", "-03"}, // America/Punta_Arenas "Canada Central Standard Time": {"CST", "CST"}, // America/Regina "Pacific SA Standard Time": {"-04", "-03"}, // America/Santiago - "E. South America Standard Time": {"-03", "-02"}, // America/Sao_Paulo + "E. South America Standard Time": {"-03", "-03"}, // America/Sao_Paulo "Newfoundland Standard Time": {"NST", "NDT"}, // America/St_Johns "Pacific Standard Time (Mexico)": {"PST", "PDT"}, // America/Tijuana + "Yukon Standard Time": {"MST", "MST"}, // America/Whitehorse "Central Asia Standard Time": {"+06", "+06"}, // Asia/Almaty - "Jordan Standard Time": {"EET", "EEST"}, // Asia/Amman + "Jordan Standard Time": {"+03", "+03"}, // Asia/Amman "Arabic Standard Time": {"+03", "+03"}, // Asia/Baghdad "Azerbaijan Standard Time": {"+04", "+04"}, // Asia/Baku "SE Asia Standard Time": {"+07", "+07"}, // Asia/Bangkok @@ -66,7 +68,7 @@ var abbrs = map[string]abbr{ "India Standard Time": {"IST", "IST"}, // Asia/Calcutta "Transbaikal Standard Time": {"+09", "+09"}, // Asia/Chita "Sri Lanka Standard Time": {"+0530", "+0530"}, // Asia/Colombo - "Syria Standard Time": {"EET", "EEST"}, // Asia/Damascus + "Syria Standard Time": {"+03", "+03"}, // Asia/Damascus "Bangladesh Standard Time": {"+06", "+06"}, // Asia/Dhaka "Arabian Standard Time": {"+04", "+04"}, // Asia/Dubai "West Bank Standard Time": {"EET", "EEST"}, // Asia/Hebron @@ -82,7 +84,7 @@ var abbrs = map[string]abbr{ "N. Central Asia Standard Time": {"+07", "+07"}, // Asia/Novosibirsk "Omsk Standard Time": {"+06", "+06"}, // Asia/Omsk "North Korea Standard Time": {"KST", "KST"}, // Asia/Pyongyang - "Qyzylorda Standard Time": {"+05", "+06"}, // Asia/Qyzylorda + "Qyzylorda Standard Time": {"+05", "+05"}, // Asia/Qyzylorda "Myanmar Standard Time": {"+0630", "+0630"}, // Asia/Rangoon "Arab Standard Time": {"+03", "+03"}, // Asia/Riyadh "Sakhalin Standard Time": {"+11", "+11"}, // Asia/Sakhalin @@ -93,7 +95,7 @@ var abbrs = map[string]abbr{ "Taipei Standard Time": {"CST", "CST"}, // Asia/Taipei "West Asia Standard Time": {"+05", "+05"}, // Asia/Tashkent "Georgian Standard Time": {"+04", "+04"}, // Asia/Tbilisi - "Iran Standard Time": {"+0330", "+0430"}, // Asia/Tehran + "Iran Standard Time": {"+0330", "+0330"}, // Asia/Tehran "Tokyo Standard Time": {"JST", "JST"}, // Asia/Tokyo "Tomsk Standard Time": {"+07", "+07"}, // Asia/Tomsk "Ulaanbaatar Standard Time": {"+08", "+08"}, // Asia/Ulaanbaatar @@ -112,7 +114,6 @@ var abbrs = map[string]abbr{ "Lord Howe Standard Time": {"+1030", "+11"}, // Australia/Lord_Howe "W. Australia Standard Time": {"AWST", "AWST"}, // Australia/Perth "AUS Eastern Standard Time": {"AEST", "AEDT"}, // Australia/Sydney - "UTC": {"GMT", "GMT"}, // Etc/GMT "UTC-11": {"-11", "-11"}, // Etc/GMT+11 "Dateline Standard Time": {"-12", "-12"}, // Etc/GMT+12 "UTC-02": {"-02", "-02"}, // Etc/GMT+2 @@ -120,6 +121,7 @@ var abbrs = map[string]abbr{ "UTC-09": {"-09", "-09"}, // Etc/GMT+9 "UTC+12": {"+12", "+12"}, // Etc/GMT-12 "UTC+13": {"+13", "+13"}, // Etc/GMT-13 + "UTC": {"UTC", "UTC"}, // Etc/UTC "Astrakhan Standard Time": {"+04", "+04"}, // Europe/Astrakhan "W. Europe Standard Time": {"CET", "CEST"}, // Europe/Berlin "GTB Standard Time": {"EET", "EEST"}, // Europe/Bucharest @@ -134,20 +136,20 @@ var abbrs = map[string]abbr{ "Romance Standard Time": {"CET", "CEST"}, // Europe/Paris "Russia Time Zone 3": {"+04", "+04"}, // Europe/Samara "Saratov Standard Time": {"+04", "+04"}, // Europe/Saratov - "Volgograd Standard Time": {"+04", "+04"}, // Europe/Volgograd + "Volgograd Standard Time": {"+03", "+03"}, // Europe/Volgograd "Central European Standard Time": {"CET", "CEST"}, // Europe/Warsaw "Mauritius Standard Time": {"+04", "+04"}, // Indian/Mauritius - "Samoa Standard Time": {"+13", "+14"}, // Pacific/Apia + "Samoa Standard Time": {"+13", "+13"}, // Pacific/Apia "New Zealand Standard Time": {"NZST", "NZDT"}, // Pacific/Auckland "Bougainville Standard Time": {"+11", "+11"}, // Pacific/Bougainville "Chatham Islands Standard Time": {"+1245", "+1345"}, // Pacific/Chatham "Easter Island Standard Time": {"-06", "-05"}, // Pacific/Easter - "Fiji Standard Time": {"+12", "+13"}, // Pacific/Fiji + "Fiji Standard Time": {"+12", "+12"}, // Pacific/Fiji "Central Pacific Standard Time": {"+11", "+11"}, // Pacific/Guadalcanal "Hawaiian Standard Time": {"HST", "HST"}, // Pacific/Honolulu "Line Islands Standard Time": {"+14", "+14"}, // Pacific/Kiritimati "Marquesas Standard Time": {"-0930", "-0930"}, // Pacific/Marquesas - "Norfolk Standard Time": {"+11", "+11"}, // Pacific/Norfolk + "Norfolk Standard Time": {"+11", "+12"}, // Pacific/Norfolk "West Pacific Standard Time": {"+10", "+10"}, // Pacific/Port_Moresby "Tonga Standard Time": {"+13", "+13"}, // Pacific/Tongatapu } From fbba58a0a4f5ff4f3aa4cfa0d494b6d2fefd068a Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Wed, 1 Feb 2023 12:15:08 -0500 Subject: [PATCH 06/17] [release-branch.go1.20] cmd/link: keep go.buildinfo even with --gc-sections If you use an external linker with --gc-sections, nothing refers to .go.buildinfo, so the section is deleted, which in turns makes 'go version' fail on the binary. It is important for vulnerability scanning and the like to be able to run 'go version' on any binary. Fix this by inserting a reference to .go.buildinfo from the rodata section, which will not be GC'ed. Fixes #58222. Fixes #58224. Change-Id: I1e13e9464acaf2f5cc5e0b70476fa52b43651123 Reviewed-on: https://go-review.googlesource.com/c/go/+/464435 Run-TryBot: Russ Cox Reviewed-by: Cherry Mui Reviewed-by: Than McIntosh Auto-Submit: Russ Cox TryBot-Result: Gopher Robot Reviewed-on: https://go-review.googlesource.com/c/go/+/464796 --- .../testdata/script/version_gc_sections.txt | 24 +++++++++++++++++++ src/cmd/link/internal/ld/data.go | 13 +++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 src/cmd/go/testdata/script/version_gc_sections.txt diff --git a/src/cmd/go/testdata/script/version_gc_sections.txt b/src/cmd/go/testdata/script/version_gc_sections.txt new file mode 100644 index 0000000000000..4bda23ba95360 --- /dev/null +++ b/src/cmd/go/testdata/script/version_gc_sections.txt @@ -0,0 +1,24 @@ +# This test checks that external linking with --gc-sections does not strip version information. + +[short] skip +[!cgo] skip +[GOOS:aix] skip # no --gc-sections +[GOOS:darwin] skip # no --gc-sections + +go build -ldflags='-linkmode=external -extldflags=-Wl,--gc-sections' +go version hello$GOEXE +! stdout 'not a Go executable' +! stderr 'not a Go executable' + +-- go.mod -- +module hello +-- hello.go -- +package main + +/* +*/ +import "C" + +func main() { + println("hello") +} diff --git a/src/cmd/link/internal/ld/data.go b/src/cmd/link/internal/ld/data.go index 94f8fc32d6e2d..925e554b1d9db 100644 --- a/src/cmd/link/internal/ld/data.go +++ b/src/cmd/link/internal/ld/data.go @@ -1669,6 +1669,9 @@ func (ctxt *Link) dodata(symGroupType []sym.SymKind) { func (state *dodataState) allocateDataSectionForSym(seg *sym.Segment, s loader.Sym, rwx int) *sym.Section { ldr := state.ctxt.loader sname := ldr.SymName(s) + if strings.HasPrefix(sname, "go:") { + sname = ".go." + sname[len("go:"):] + } sect := addsection(ldr, state.ctxt.Arch, seg, sname, rwx) sect.Align = symalign(ldr, s) state.datsize = Rnd(state.datsize, int64(sect.Align)) @@ -2254,7 +2257,7 @@ func (ctxt *Link) buildinfo() { // Write the buildinfo symbol, which go version looks for. // The code reading this data is in package debug/buildinfo. ldr := ctxt.loader - s := ldr.CreateSymForUpdate(".go.buildinfo", 0) + s := ldr.CreateSymForUpdate("go:buildinfo", 0) s.SetType(sym.SBUILDINFO) s.SetAlign(16) // The \xff is invalid UTF-8, meant to make it less likely @@ -2276,6 +2279,14 @@ func (ctxt *Link) buildinfo() { } s.SetData(data) s.SetSize(int64(len(data))) + + // Add reference to go:buildinfo from the rodata section, + // so that external linking with -Wl,--gc-sections does not + // delete the build info. + sr := ldr.CreateSymForUpdate("go:buildinfo.ref", 0) + sr.SetType(sym.SRODATA) + sr.SetAlign(int32(ctxt.Arch.PtrSize)) + sr.AddAddr(ctxt.Arch, s.Sym()) } // appendString appends s to data, prefixed by its varint-encoded length. From a943fd0cccc6043e6a3397659f3f262544e615b2 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Wed, 8 Feb 2023 14:02:55 -0500 Subject: [PATCH 07/17] [release-branch.go1.20] runtime: skip darwin osinit_hack on ios Darwin needs the osinit_hack call to fix some bugs in the Apple libc that surface when Go programs call exec. On iOS, the functions that osinit_hack uses are not available, so signing fails. But on iOS exec is also unavailable, so the hack is not needed. Disable it there, which makes signing work again. Fixes #58323. Fixes #58419. Change-Id: I3f1472f852bb36c06854fe1f14aa27ad450c5945 Reviewed-on: https://go-review.googlesource.com/c/go/+/466516 Run-TryBot: Russ Cox Reviewed-by: Dave Anderson Reviewed-by: Michael Knyszek TryBot-Result: Gopher Robot Auto-Submit: Russ Cox Reviewed-by: Bryan Mills Reviewed-by: Than McIntosh Reviewed-on: https://go-review.googlesource.com/c/go/+/467316 --- src/runtime/sys_darwin.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/sys_darwin.go b/src/runtime/sys_darwin.go index 8bff695f5729d..5ba697e3047f4 100644 --- a/src/runtime/sys_darwin.go +++ b/src/runtime/sys_darwin.go @@ -213,7 +213,9 @@ func pthread_kill_trampoline() // //go:nosplit func osinit_hack() { - libcCall(unsafe.Pointer(abi.FuncPCABI0(osinit_hack_trampoline)), nil) + if GOOS == "darwin" { // not ios + libcCall(unsafe.Pointer(abi.FuncPCABI0(osinit_hack_trampoline)), nil) + } return } func osinit_hack_trampoline() From 1fa2deb1b1a620511a3c45fcbae895e78d4f5d40 Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Thu, 9 Feb 2023 16:37:51 -0500 Subject: [PATCH 08/17] [release-branch.go1.20] cmd/go: remove tests that assume lack of new versions of external modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In general it seems ok to assume that an open-source module that did exist will continue to do so — after all, users of open-source modules already do that all the time. However, we should not assume that those modules do not publish new versions — that's really up to their maintainers to decide. Two existing tests did make that assumption for the module gopkg.in/natefinch/lumberjack.v2. Let's remove those two tests. If we need to replace them at some point, we can replace them with hermetic test-only modules (#54503) or perhaps modules owned by the Go project. Updates #58445. Fixes #58450. Change-Id: Ica8fe587d86fc41f3d8445a4cd2b8820455ae45f Reviewed-on: https://go-review.googlesource.com/c/go/+/466861 TryBot-Result: Gopher Robot Run-TryBot: Bryan Mills Reviewed-by: David Chase --- src/cmd/go/internal/modfetch/coderepo_test.go | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/cmd/go/internal/modfetch/coderepo_test.go b/src/cmd/go/internal/modfetch/coderepo_test.go index 553946ba369e7..8ccd9b2dca691 100644 --- a/src/cmd/go/internal/modfetch/coderepo_test.go +++ b/src/cmd/go/internal/modfetch/coderepo_test.go @@ -404,18 +404,6 @@ var codeRepoTests = []codeRepoTest{ zipSum: "h1:YJYZRsM9BHFTlVr8YADjT0cJH8uFIDtoc5NLiVqZEx8=", zipFileHash: "c15e49d58b7a4c37966cbe5bc01a0330cd5f2927e990e1839bda1d407766d9c5", }, - { - vcs: "git", - path: "gopkg.in/natefinch/lumberjack.v2", - rev: "latest", - version: "v2.0.0-20170531160350-a96e63847dc3", - name: "a96e63847dc3c67d17befa69c303767e2f84e54f", - short: "a96e63847dc3", - time: time.Date(2017, 5, 31, 16, 3, 50, 0, time.UTC), - gomod: "module gopkg.in/natefinch/lumberjack.v2\n", - zipSum: "h1:AFxeG48hTWHhDTQDk/m2gorfVHUEa9vo3tp3D7TzwjI=", - zipFileHash: "b5de0da7bbbec76709eef1ac71b6c9ff423b9fbf3bb97b56743450d4937b06d5", - }, { vcs: "git", path: "gopkg.in/natefinch/lumberjack.v2", @@ -818,11 +806,6 @@ var codeRepoVersionsTests = []struct { path: "swtch.com/testmod", versions: []string{"v1.0.0", "v1.1.1"}, }, - { - vcs: "git", - path: "gopkg.in/natefinch/lumberjack.v2", - versions: []string{"v2.0.0"}, - }, { vcs: "git", path: "vcs-test.golang.org/git/odd-tags.git", From 7628627cb236662002b53686ff0618834a9aa077 Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Wed, 8 Feb 2023 14:36:47 -0500 Subject: [PATCH 09/17] [release-branch.go1.20] cmd/go/internal/test: refresh flagdefs.go and fix test The tests for cmd/go/internal/test were not running at all due to a missed call to m.Run in TestMain. That masked a missing vet analyzer ("timeformat") and a missed update to the generator script in CL 355452. Fixes #58421. Updates #58415. Change-Id: I7b0315952967ca07a866cdaa5903478b2873eb7a Reviewed-on: https://go-review.googlesource.com/c/go/+/466635 TryBot-Result: Gopher Robot Reviewed-by: Ian Lance Taylor Auto-Submit: Bryan Mills Run-TryBot: Bryan Mills (cherry picked from commit 910f041ff0cdf90dbcd3bd22a272b9b7205a5add) Reviewed-on: https://go-review.googlesource.com/c/go/+/466855 --- src/cmd/go/internal/test/flagdefs.go | 1 + src/cmd/go/internal/test/flagdefs_test.go | 4 ++++ src/cmd/go/internal/test/genflags.go | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/cmd/go/internal/test/flagdefs.go b/src/cmd/go/internal/test/flagdefs.go index b91204ee93758..3f3709fe7e28b 100644 --- a/src/cmd/go/internal/test/flagdefs.go +++ b/src/cmd/go/internal/test/flagdefs.go @@ -66,6 +66,7 @@ var passAnalyzersToVet = map[string]bool{ "structtag": true, "testinggoroutine": true, "tests": true, + "timeformat": true, "unmarshal": true, "unreachable": true, "unsafeptr": true, diff --git a/src/cmd/go/internal/test/flagdefs_test.go b/src/cmd/go/internal/test/flagdefs_test.go index 337f136d06177..d5facb7161ab1 100644 --- a/src/cmd/go/internal/test/flagdefs_test.go +++ b/src/cmd/go/internal/test/flagdefs_test.go @@ -9,6 +9,7 @@ import ( "cmd/go/internal/test/internal/genflags" "flag" "internal/testenv" + "os" "reflect" "strings" "testing" @@ -16,6 +17,7 @@ import ( func TestMain(m *testing.M) { cfg.SetGOROOT(testenv.GOROOT(nil), false) + os.Exit(m.Run()) } func TestPassFlagToTestIncludesAllTestFlags(t *testing.T) { @@ -48,6 +50,8 @@ func TestPassFlagToTestIncludesAllTestFlags(t *testing.T) { } func TestVetAnalyzersSetIsCorrect(t *testing.T) { + testenv.MustHaveGoBuild(t) // runs 'go tool vet -flags' + vetAns, err := genflags.VetAnalyzers() if err != nil { t.Fatal(err) diff --git a/src/cmd/go/internal/test/genflags.go b/src/cmd/go/internal/test/genflags.go index 8c7554919a5c8..625f94133a147 100644 --- a/src/cmd/go/internal/test/genflags.go +++ b/src/cmd/go/internal/test/genflags.go @@ -75,7 +75,7 @@ func testFlags() []string { } switch name { - case "testlogfile", "paniconexit0", "fuzzcachedir", "fuzzworker": + case "testlogfile", "paniconexit0", "fuzzcachedir", "fuzzworker", "gocoverdir": // These flags are only for use by cmd/go. default: names = append(names, name) From 00f5d3001a7e684263307ab39c64eba3c79f279c Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Tue, 31 Jan 2023 17:21:14 -0500 Subject: [PATCH 10/17] [release-branch.go1.20] cmd/go/internal/script: retry ETXTBSY errors in scripts Fixes #58431. Updates #58019. Change-Id: Ib25d668bfede6e87a3786f44bdc0db1027e3ebec Reviewed-on: https://go-review.googlesource.com/c/go/+/463748 TryBot-Result: Gopher Robot Auto-Submit: Bryan Mills Run-TryBot: Bryan Mills Reviewed-by: Ian Lance Taylor (cherry picked from commit 23c0121e4eb259cc1087d0f79a0803cbc71f500b) Reviewed-on: https://go-review.googlesource.com/c/go/+/466856 Reviewed-by: David Chase --- src/cmd/go/internal/script/cmds.go | 46 ++++++++++++++++-------- src/cmd/go/internal/script/cmds_other.go | 11 ++++++ src/cmd/go/internal/script/cmds_posix.go | 16 +++++++++ 3 files changed, 58 insertions(+), 15 deletions(-) create mode 100644 src/cmd/go/internal/script/cmds_other.go create mode 100644 src/cmd/go/internal/script/cmds_posix.go diff --git a/src/cmd/go/internal/script/cmds.go b/src/cmd/go/internal/script/cmds.go index e0eaad4c43d19..666d2d62d30f5 100644 --- a/src/cmd/go/internal/script/cmds.go +++ b/src/cmd/go/internal/script/cmds.go @@ -432,21 +432,37 @@ func Exec(cancel func(*exec.Cmd) error, waitDelay time.Duration) Cmd { } func startCommand(s *State, name, path string, args []string, cancel func(*exec.Cmd) error, waitDelay time.Duration) (WaitFunc, error) { - var stdoutBuf, stderrBuf strings.Builder - cmd := exec.CommandContext(s.Context(), path, args...) - if cancel == nil { - cmd.Cancel = nil - } else { - cmd.Cancel = func() error { return cancel(cmd) } - } - cmd.WaitDelay = waitDelay - cmd.Args[0] = name - cmd.Dir = s.Getwd() - cmd.Env = s.env - cmd.Stdout = &stdoutBuf - cmd.Stderr = &stderrBuf - if err := cmd.Start(); err != nil { - return nil, err + var ( + cmd *exec.Cmd + stdoutBuf, stderrBuf strings.Builder + ) + for { + cmd = exec.CommandContext(s.Context(), path, args...) + if cancel == nil { + cmd.Cancel = nil + } else { + cmd.Cancel = func() error { return cancel(cmd) } + } + cmd.WaitDelay = waitDelay + cmd.Args[0] = name + cmd.Dir = s.Getwd() + cmd.Env = s.env + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + err := cmd.Start() + if err == nil { + break + } + if isETXTBSY(err) { + // If the script (or its host process) just wrote the executable we're + // trying to run, a fork+exec in another thread may be holding open the FD + // that we used to write the executable (see https://go.dev/issue/22315). + // Since the descriptor should have CLOEXEC set, the problem should + // resolve as soon as the forked child reaches its exec call. + // Keep retrying until that happens. + } else { + return nil, err + } } wait := func(s *State) (stdout, stderr string, err error) { diff --git a/src/cmd/go/internal/script/cmds_other.go b/src/cmd/go/internal/script/cmds_other.go new file mode 100644 index 0000000000000..847b225ae6498 --- /dev/null +++ b/src/cmd/go/internal/script/cmds_other.go @@ -0,0 +1,11 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !(unix || windows) + +package script + +func isETXTBSY(err error) bool { + return false +} diff --git a/src/cmd/go/internal/script/cmds_posix.go b/src/cmd/go/internal/script/cmds_posix.go new file mode 100644 index 0000000000000..2525f6e7529d8 --- /dev/null +++ b/src/cmd/go/internal/script/cmds_posix.go @@ -0,0 +1,16 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build unix || windows + +package script + +import ( + "errors" + "syscall" +) + +func isETXTBSY(err error) bool { + return errors.Is(err, syscall.ETXTBSY) +} From 3a04b6e12ef0e5a0c608f82051943408bd6f28bd Mon Sep 17 00:00:00 2001 From: Frederic Branczyk Date: Wed, 8 Feb 2023 17:59:27 +0000 Subject: [PATCH 11/17] [release-branch.go1.20] cmd/compile/internal/pgo: fix hard-coded PGO sample data position This patch detects at which index position profiling samples that have the value-type samples count are, instead of the previously hard-coded position of index 1. Runtime generated profiles always generate CPU profiling data with the 0 index being CPU nanoseconds, and samples count at index 1, which is why this previously hasn't come up. This is a redo of CL 465135, now allowing empty profiles. Note that preprocessProfileGraph will already cause pgo.New to return nil for empty profiles. For #58292 For #58309 Change-Id: Ia6c94f0793f6ca9b0882b5e2c4d34f38e600c1e3 Reviewed-on: https://go-review.googlesource.com/c/go/+/467375 Run-TryBot: Michael Pratt TryBot-Result: Gopher Robot Reviewed-by: Austin Clements --- src/cmd/compile/internal/pgo/irgraph.go | 23 ++++++- src/cmd/compile/internal/test/pgo_inl_test.go | 68 +++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/src/cmd/compile/internal/pgo/irgraph.go b/src/cmd/compile/internal/pgo/irgraph.go index bf11e365f1042..bb5df50e3ad32 100644 --- a/src/cmd/compile/internal/pgo/irgraph.go +++ b/src/cmd/compile/internal/pgo/irgraph.go @@ -140,9 +140,30 @@ func New(profileFile string) *Profile { return nil } + if len(profile.Sample) == 0 { + // We accept empty profiles, but there is nothing to do. + return nil + } + + valueIndex := -1 + for i, s := range profile.SampleType { + // Samples count is the raw data collected, and CPU nanoseconds is just + // a scaled version of it, so either one we can find is fine. + if (s.Type == "samples" && s.Unit == "count") || + (s.Type == "cpu" && s.Unit == "nanoseconds") { + valueIndex = i + break + } + } + + if valueIndex == -1 { + log.Fatal("failed to find CPU samples count or CPU nanoseconds value-types in profile.") + return nil + } + g := newGraph(profile, &Options{ CallTree: false, - SampleValue: func(v []int64) int64 { return v[1] }, + SampleValue: func(v []int64) int64 { return v[valueIndex] }, }) p := &Profile{ diff --git a/src/cmd/compile/internal/test/pgo_inl_test.go b/src/cmd/compile/internal/test/pgo_inl_test.go index 2f6391fded265..4d6b5a134a0e2 100644 --- a/src/cmd/compile/internal/test/pgo_inl_test.go +++ b/src/cmd/compile/internal/test/pgo_inl_test.go @@ -7,6 +7,7 @@ package test import ( "bufio" "fmt" + "internal/profile" "internal/testenv" "io" "os" @@ -213,6 +214,73 @@ func TestPGOIntendedInliningShiftedLines(t *testing.T) { testPGOIntendedInlining(t, dir) } +// TestPGOSingleIndex tests that the sample index can not be 1 and compilation +// will not fail. All it should care about is that the sample type is either +// CPU nanoseconds or samples count, whichever it finds first. +func TestPGOSingleIndex(t *testing.T) { + for _, tc := range []struct { + originalIndex int + }{{ + // The `testdata/pgo/inline/inline_hot.pprof` file is a standard CPU + // profile as the runtime would generate. The 0 index contains the + // value-type samples and value-unit count. The 1 index contains the + // value-type cpu and value-unit nanoseconds. These tests ensure that + // the compiler can work with profiles that only have a single index, + // but are either samples count or CPU nanoseconds. + originalIndex: 0, + }, { + originalIndex: 1, + }} { + t.Run(fmt.Sprintf("originalIndex=%d", tc.originalIndex), func(t *testing.T) { + wd, err := os.Getwd() + if err != nil { + t.Fatalf("error getting wd: %v", err) + } + srcDir := filepath.Join(wd, "testdata/pgo/inline") + + // Copy the module to a scratch location so we can add a go.mod. + dir := t.TempDir() + + originalPprofFile, err := os.Open(filepath.Join(srcDir, "inline_hot.pprof")) + if err != nil { + t.Fatalf("error opening inline_hot.pprof: %v", err) + } + defer originalPprofFile.Close() + + p, err := profile.Parse(originalPprofFile) + if err != nil { + t.Fatalf("error parsing inline_hot.pprof: %v", err) + } + + // Move the samples count value-type to the 0 index. + p.SampleType = []*profile.ValueType{p.SampleType[tc.originalIndex]} + + // Ensure we only have a single set of sample values. + for _, s := range p.Sample { + s.Value = []int64{s.Value[tc.originalIndex]} + } + + modifiedPprofFile, err := os.Create(filepath.Join(dir, "inline_hot.pprof")) + if err != nil { + t.Fatalf("error creating inline_hot.pprof: %v", err) + } + defer modifiedPprofFile.Close() + + if err := p.Write(modifiedPprofFile); err != nil { + t.Fatalf("error writing inline_hot.pprof: %v", err) + } + + for _, file := range []string{"inline_hot.go", "inline_hot_test.go"} { + if err := copyFile(filepath.Join(dir, file), filepath.Join(srcDir, file)); err != nil { + t.Fatalf("error copying %s: %v", file, err) + } + } + + testPGOIntendedInlining(t, dir) + }) + } +} + func copyFile(dst, src string) error { s, err := os.Open(src) if err != nil { From bdf07c2e168baf736e4c057279ca12a4d674f18c Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 12 Dec 2022 16:43:37 -0800 Subject: [PATCH 12/17] [release-branch.go1.20] path/filepath: do not Clean("a/../c:/b") into c:\b on Windows Do not permit Clean to convert a relative path into one starting with a drive reference. This change causes Clean to insert a . path element at the start of a path when the original path does not start with a volume name, and the first path element would contain a colon. This may introduce a spurious but harmless . path element under some circumstances. For example, Clean("a/../b:/../c") becomes `.\c`. This reverts CL 401595, since the change here supersedes the one in that CL. Thanks to RyotaK (https://twitter.com/ryotkak) for reporting this issue. Updates #57274 Fixes #57276 Fixes CVE-2022-41722 Change-Id: I837446285a03aa74c79d7642720e01f354c2ca17 Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1675249 Reviewed-by: Roland Shoemaker Run-TryBot: Damien Neil Reviewed-by: Julie Qiu TryBot-Result: Security TryBots (cherry picked from commit 8ca37f4813ef2f64600c92b83f17c9f3ca6c03a5) Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728944 Run-TryBot: Roland Shoemaker Reviewed-by: Tatiana Bradley Reviewed-by: Damien Neil Reviewed-on: https://go-review.googlesource.com/c/go/+/468119 Reviewed-by: Than McIntosh Run-TryBot: Michael Pratt TryBot-Result: Gopher Robot Auto-Submit: Michael Pratt --- src/path/filepath/path.go | 27 +++++++++++++------------- src/path/filepath/path_test.go | 8 ++++++++ src/path/filepath/path_windows_test.go | 2 +- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/path/filepath/path.go b/src/path/filepath/path.go index a6578cbb728b9..32dd887998650 100644 --- a/src/path/filepath/path.go +++ b/src/path/filepath/path.go @@ -15,6 +15,7 @@ import ( "errors" "io/fs" "os" + "runtime" "sort" "strings" ) @@ -117,21 +118,9 @@ func Clean(path string) string { case os.IsPathSeparator(path[r]): // empty path element r++ - case path[r] == '.' && r+1 == n: + case path[r] == '.' && (r+1 == n || os.IsPathSeparator(path[r+1])): // . element r++ - case path[r] == '.' && os.IsPathSeparator(path[r+1]): - // ./ element - r++ - - for r < len(path) && os.IsPathSeparator(path[r]) { - r++ - } - if out.w == 0 && volumeNameLen(path[r:]) > 0 { - // When joining prefix "." and an absolute path on Windows, - // the prefix should not be removed. - out.append('.') - } case path[r] == '.' && path[r+1] == '.' && (r+2 == n || os.IsPathSeparator(path[r+2])): // .. element: remove to last separator r += 2 @@ -157,6 +146,18 @@ func Clean(path string) string { if rooted && out.w != 1 || !rooted && out.w != 0 { out.append(Separator) } + // If a ':' appears in the path element at the start of a Windows path, + // insert a .\ at the beginning to avoid converting relative paths + // like a/../c: into c:. + if runtime.GOOS == "windows" && out.w == 0 && out.volLen == 0 && r != 0 { + for i := r; i < n && !os.IsPathSeparator(path[i]); i++ { + if path[i] == ':' { + out.append('.') + out.append(Separator) + break + } + } + } // copy element for ; r < n && !os.IsPathSeparator(path[r]); r++ { out.append(path[r]) diff --git a/src/path/filepath/path_test.go b/src/path/filepath/path_test.go index 6647444852667..697bcc672d73c 100644 --- a/src/path/filepath/path_test.go +++ b/src/path/filepath/path_test.go @@ -106,6 +106,13 @@ var wincleantests = []PathTest{ {`//abc`, `\\abc`}, {`///abc`, `\\\abc`}, {`//abc//`, `\\abc\\`}, + + // Don't allow cleaning to move an element with a colon to the start of the path. + {`a/../c:`, `.\c:`}, + {`a\..\c:`, `.\c:`}, + {`a/../c:/a`, `.\c:\a`}, + {`a/../../c:`, `..\c:`}, + {`foo:bar`, `foo:bar`}, } func TestClean(t *testing.T) { @@ -174,6 +181,7 @@ var winislocaltests = []IsLocalTest{ {`C:`, false}, {`C:\a`, false}, {`..\a`, false}, + {`a/../c:`, false}, {`CONIN$`, false}, {`conin$`, false}, {`CONOUT$`, false}, diff --git a/src/path/filepath/path_windows_test.go b/src/path/filepath/path_windows_test.go index e37dddceadbdf..c8c7eefcc0bb4 100644 --- a/src/path/filepath/path_windows_test.go +++ b/src/path/filepath/path_windows_test.go @@ -542,7 +542,7 @@ func TestIssue52476(t *testing.T) { }{ {`..\.`, `C:`, `..\C:`}, {`..`, `C:`, `..\C:`}, - {`.`, `:`, `:`}, + {`.`, `:`, `.\:`}, {`.`, `C:`, `.\C:`}, {`.`, `C:/a/b/../c`, `.\C:\a\c`}, {`.`, `\C:`, `.\C:`}, From 53b43607d92e9738067c93829bd799441eda8034 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 25 Jan 2023 09:27:01 -0800 Subject: [PATCH 13/17] [release-branch.go1.20] mime/multipart: limit memory/inode consumption of ReadForm Reader.ReadForm is documented as storing "up to maxMemory bytes + 10MB" in memory. Parsed forms can consume substantially more memory than this limit, since ReadForm does not account for map entry overhead and MIME headers. In addition, while the amount of disk memory consumed by ReadForm can be constrained by limiting the size of the parsed input, ReadForm will create one temporary file per form part stored on disk, potentially consuming a large number of inodes. Update ReadForm's memory accounting to include part names, MIME headers, and map entry overhead. Update ReadForm to store all on-disk file parts in a single temporary file. Files returned by FileHeader.Open are documented as having a concrete type of *os.File when a file is stored on disk. The change to use a single temporary file for all parts means that this is no longer the case when a form contains more than a single file part stored on disk. The previous behavior of storing each file part in a separate disk file may be reenabled with GODEBUG=multipartfiles=distinct. Update Reader.NextPart and Reader.NextRawPart to set a 10MiB cap on the size of MIME headers. Thanks to Jakob Ackermann (@das7pad) for reporting this issue. Updates #58006 Fixes #58363 Fixes CVE-2022-41725 Change-Id: Ibd780a6c4c83ac8bcfd3cbe344f042e9940f2eab Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1714276 Reviewed-by: Julie Qiu TryBot-Result: Security TryBots Reviewed-by: Roland Shoemaker Run-TryBot: Damien Neil (cherry picked from commit 7d0da0029bfbe3228cc5216ced8c7b3184eb517d) Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728950 Reviewed-by: Damien Neil Run-TryBot: Roland Shoemaker Reviewed-by: Tatiana Bradley Reviewed-on: https://go-review.googlesource.com/c/go/+/468120 Auto-Submit: Michael Pratt Run-TryBot: Michael Pratt Reviewed-by: Than McIntosh TryBot-Result: Gopher Robot --- src/mime/multipart/formdata.go | 133 ++++++++++++++++++++----- src/mime/multipart/formdata_test.go | 141 ++++++++++++++++++++++++++- src/mime/multipart/multipart.go | 25 +++-- src/mime/multipart/readmimeheader.go | 14 +++ src/net/http/request_test.go | 2 +- src/net/textproto/reader.go | 20 +++- 6 files changed, 297 insertions(+), 38 deletions(-) create mode 100644 src/mime/multipart/readmimeheader.go diff --git a/src/mime/multipart/formdata.go b/src/mime/multipart/formdata.go index fca5f9e15fb21..41bc886d1679d 100644 --- a/src/mime/multipart/formdata.go +++ b/src/mime/multipart/formdata.go @@ -7,6 +7,7 @@ package multipart import ( "bytes" "errors" + "internal/godebug" "io" "math" "net/textproto" @@ -31,25 +32,61 @@ func (r *Reader) ReadForm(maxMemory int64) (*Form, error) { return r.readForm(maxMemory) } +var multipartFiles = godebug.New("multipartfiles") + func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { form := &Form{make(map[string][]string), make(map[string][]*FileHeader)} + var ( + file *os.File + fileOff int64 + ) + numDiskFiles := 0 + combineFiles := multipartFiles.Value() != "distinct" defer func() { + if file != nil { + if cerr := file.Close(); err == nil { + err = cerr + } + } + if combineFiles && numDiskFiles > 1 { + for _, fhs := range form.File { + for _, fh := range fhs { + fh.tmpshared = true + } + } + } if err != nil { form.RemoveAll() + if file != nil { + os.Remove(file.Name()) + } } }() - // Reserve an additional 10 MB for non-file parts. - maxValueBytes := maxMemory + int64(10<<20) - if maxValueBytes <= 0 { + // maxFileMemoryBytes is the maximum bytes of file data we will store in memory. + // Data past this limit is written to disk. + // This limit strictly applies to content, not metadata (filenames, MIME headers, etc.), + // since metadata is always stored in memory, not disk. + // + // maxMemoryBytes is the maximum bytes we will store in memory, including file content, + // non-file part values, metdata, and map entry overhead. + // + // We reserve an additional 10 MB in maxMemoryBytes for non-file data. + // + // The relationship between these parameters, as well as the overly-large and + // unconfigurable 10 MB added on to maxMemory, is unfortunate but difficult to change + // within the constraints of the API as documented. + maxFileMemoryBytes := maxMemory + maxMemoryBytes := maxMemory + int64(10<<20) + if maxMemoryBytes <= 0 { if maxMemory < 0 { - maxValueBytes = 0 + maxMemoryBytes = 0 } else { - maxValueBytes = math.MaxInt64 + maxMemoryBytes = math.MaxInt64 } } for { - p, err := r.NextPart() + p, err := r.nextPart(false, maxMemoryBytes) if err == io.EOF { break } @@ -63,16 +100,27 @@ func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { } filename := p.FileName() + // Multiple values for the same key (one map entry, longer slice) are cheaper + // than the same number of values for different keys (many map entries), but + // using a consistent per-value cost for overhead is simpler. + maxMemoryBytes -= int64(len(name)) + maxMemoryBytes -= 100 // map overhead + if maxMemoryBytes < 0 { + // We can't actually take this path, since nextPart would already have + // rejected the MIME headers for being too large. Check anyway. + return nil, ErrMessageTooLarge + } + var b bytes.Buffer if filename == "" { // value, store as string in memory - n, err := io.CopyN(&b, p, maxValueBytes+1) + n, err := io.CopyN(&b, p, maxMemoryBytes+1) if err != nil && err != io.EOF { return nil, err } - maxValueBytes -= n - if maxValueBytes < 0 { + maxMemoryBytes -= n + if maxMemoryBytes < 0 { return nil, ErrMessageTooLarge } form.Value[name] = append(form.Value[name], b.String()) @@ -80,35 +128,45 @@ func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { } // file, store in memory or on disk + maxMemoryBytes -= mimeHeaderSize(p.Header) + if maxMemoryBytes < 0 { + return nil, ErrMessageTooLarge + } fh := &FileHeader{ Filename: filename, Header: p.Header, } - n, err := io.CopyN(&b, p, maxMemory+1) + n, err := io.CopyN(&b, p, maxFileMemoryBytes+1) if err != nil && err != io.EOF { return nil, err } - if n > maxMemory { - // too big, write to disk and flush buffer - file, err := os.CreateTemp("", "multipart-") - if err != nil { - return nil, err + if n > maxFileMemoryBytes { + if file == nil { + file, err = os.CreateTemp(r.tempDir, "multipart-") + if err != nil { + return nil, err + } } + numDiskFiles++ size, err := io.Copy(file, io.MultiReader(&b, p)) - if cerr := file.Close(); err == nil { - err = cerr - } if err != nil { - os.Remove(file.Name()) return nil, err } fh.tmpfile = file.Name() fh.Size = size + fh.tmpoff = fileOff + fileOff += size + if !combineFiles { + if err := file.Close(); err != nil { + return nil, err + } + file = nil + } } else { fh.content = b.Bytes() fh.Size = int64(len(fh.content)) - maxMemory -= n - maxValueBytes -= n + maxFileMemoryBytes -= n + maxMemoryBytes -= n } form.File[name] = append(form.File[name], fh) } @@ -116,6 +174,17 @@ func (r *Reader) readForm(maxMemory int64) (_ *Form, err error) { return form, nil } +func mimeHeaderSize(h textproto.MIMEHeader) (size int64) { + for k, vs := range h { + size += int64(len(k)) + size += 100 // map entry overhead + for _, v := range vs { + size += int64(len(v)) + } + } + return size +} + // Form is a parsed multipart form. // Its File parts are stored either in memory or on disk, // and are accessible via the *FileHeader's Open method. @@ -133,7 +202,7 @@ func (f *Form) RemoveAll() error { for _, fh := range fhs { if fh.tmpfile != "" { e := os.Remove(fh.tmpfile) - if e != nil && err == nil { + if e != nil && !errors.Is(e, os.ErrNotExist) && err == nil { err = e } } @@ -148,15 +217,25 @@ type FileHeader struct { Header textproto.MIMEHeader Size int64 - content []byte - tmpfile string + content []byte + tmpfile string + tmpoff int64 + tmpshared bool } // Open opens and returns the FileHeader's associated File. func (fh *FileHeader) Open() (File, error) { if b := fh.content; b != nil { r := io.NewSectionReader(bytes.NewReader(b), 0, int64(len(b))) - return sectionReadCloser{r}, nil + return sectionReadCloser{r, nil}, nil + } + if fh.tmpshared { + f, err := os.Open(fh.tmpfile) + if err != nil { + return nil, err + } + r := io.NewSectionReader(f, fh.tmpoff, fh.Size) + return sectionReadCloser{r, f}, nil } return os.Open(fh.tmpfile) } @@ -175,8 +254,12 @@ type File interface { type sectionReadCloser struct { *io.SectionReader + io.Closer } func (rc sectionReadCloser) Close() error { + if rc.Closer != nil { + return rc.Closer.Close() + } return nil } diff --git a/src/mime/multipart/formdata_test.go b/src/mime/multipart/formdata_test.go index 8a4eabcee038a..8a862be717415 100644 --- a/src/mime/multipart/formdata_test.go +++ b/src/mime/multipart/formdata_test.go @@ -5,8 +5,11 @@ package multipart import ( + "bytes" + "fmt" "io" "math" + "net/textproto" "os" "strings" "testing" @@ -207,8 +210,8 @@ Content-Disposition: form-data; name="largetext" maxMemory int64 err error }{ - {"smaller", 50, nil}, - {"exact-fit", 25, nil}, + {"smaller", 50 + int64(len("largetext")) + 100, nil}, + {"exact-fit", 25 + int64(len("largetext")) + 100, nil}, {"too-large", 0, ErrMessageTooLarge}, } for _, tc := range testCases { @@ -223,7 +226,7 @@ Content-Disposition: form-data; name="largetext" defer f.RemoveAll() } if tc.err != err { - t.Fatalf("ReadForm error - got: %v; expected: %v", tc.err, err) + t.Fatalf("ReadForm error - got: %v; expected: %v", err, tc.err) } if err == nil { if g := f.Value["largetext"][0]; g != largeTextValue { @@ -233,3 +236,135 @@ Content-Disposition: form-data; name="largetext" }) } } + +// TestReadForm_MetadataTooLarge verifies that we account for the size of field names, +// MIME headers, and map entry overhead while limiting the memory consumption of parsed forms. +func TestReadForm_MetadataTooLarge(t *testing.T) { + for _, test := range []struct { + name string + f func(*Writer) + }{{ + name: "large name", + f: func(fw *Writer) { + name := strings.Repeat("a", 10<<20) + w, _ := fw.CreateFormField(name) + w.Write([]byte("value")) + }, + }, { + name: "large MIME header", + f: func(fw *Writer) { + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", `form-data; name="a"`) + h.Set("X-Foo", strings.Repeat("a", 10<<20)) + w, _ := fw.CreatePart(h) + w.Write([]byte("value")) + }, + }, { + name: "many parts", + f: func(fw *Writer) { + for i := 0; i < 110000; i++ { + w, _ := fw.CreateFormField("f") + w.Write([]byte("v")) + } + }, + }} { + t.Run(test.name, func(t *testing.T) { + var buf bytes.Buffer + fw := NewWriter(&buf) + test.f(fw) + if err := fw.Close(); err != nil { + t.Fatal(err) + } + fr := NewReader(&buf, fw.Boundary()) + _, err := fr.ReadForm(0) + if err != ErrMessageTooLarge { + t.Errorf("fr.ReadForm() = %v, want ErrMessageTooLarge", err) + } + }) + } +} + +// TestReadForm_ManyFiles_Combined tests that a multipart form containing many files only +// results in a single on-disk file. +func TestReadForm_ManyFiles_Combined(t *testing.T) { + const distinct = false + testReadFormManyFiles(t, distinct) +} + +// TestReadForm_ManyFiles_Distinct tests that setting GODEBUG=multipartfiles=distinct +// results in every file in a multipart form being placed in a distinct on-disk file. +func TestReadForm_ManyFiles_Distinct(t *testing.T) { + t.Setenv("GODEBUG", "multipartfiles=distinct") + const distinct = true + testReadFormManyFiles(t, distinct) +} + +func testReadFormManyFiles(t *testing.T, distinct bool) { + var buf bytes.Buffer + fw := NewWriter(&buf) + const numFiles = 10 + for i := 0; i < numFiles; i++ { + name := fmt.Sprint(i) + w, err := fw.CreateFormFile(name, name) + if err != nil { + t.Fatal(err) + } + w.Write([]byte(name)) + } + if err := fw.Close(); err != nil { + t.Fatal(err) + } + fr := NewReader(&buf, fw.Boundary()) + fr.tempDir = t.TempDir() + form, err := fr.ReadForm(0) + if err != nil { + t.Fatal(err) + } + for i := 0; i < numFiles; i++ { + name := fmt.Sprint(i) + if got := len(form.File[name]); got != 1 { + t.Fatalf("form.File[%q] has %v entries, want 1", name, got) + } + fh := form.File[name][0] + file, err := fh.Open() + if err != nil { + t.Fatalf("form.File[%q].Open() = %v", name, err) + } + if distinct { + if _, ok := file.(*os.File); !ok { + t.Fatalf("form.File[%q].Open: %T, want *os.File", name, file) + } + } + got, err := io.ReadAll(file) + file.Close() + if string(got) != name || err != nil { + t.Fatalf("read form.File[%q]: %q, %v; want %q, nil", name, string(got), err, name) + } + } + dir, err := os.Open(fr.tempDir) + if err != nil { + t.Fatal(err) + } + defer dir.Close() + names, err := dir.Readdirnames(0) + if err != nil { + t.Fatal(err) + } + wantNames := 1 + if distinct { + wantNames = numFiles + } + if len(names) != wantNames { + t.Fatalf("temp dir contains %v files; want 1", len(names)) + } + if err := form.RemoveAll(); err != nil { + t.Fatalf("form.RemoveAll() = %v", err) + } + names, err = dir.Readdirnames(0) + if err != nil { + t.Fatal(err) + } + if len(names) != 0 { + t.Fatalf("temp dir contains %v files; want 0", len(names)) + } +} diff --git a/src/mime/multipart/multipart.go b/src/mime/multipart/multipart.go index b3a904f0aff34..86ea926346eb5 100644 --- a/src/mime/multipart/multipart.go +++ b/src/mime/multipart/multipart.go @@ -128,12 +128,12 @@ func (r *stickyErrorReader) Read(p []byte) (n int, _ error) { return n, r.err } -func newPart(mr *Reader, rawPart bool) (*Part, error) { +func newPart(mr *Reader, rawPart bool, maxMIMEHeaderSize int64) (*Part, error) { bp := &Part{ Header: make(map[string][]string), mr: mr, } - if err := bp.populateHeaders(); err != nil { + if err := bp.populateHeaders(maxMIMEHeaderSize); err != nil { return nil, err } bp.r = partReader{bp} @@ -149,12 +149,16 @@ func newPart(mr *Reader, rawPart bool) (*Part, error) { return bp, nil } -func (p *Part) populateHeaders() error { +func (p *Part) populateHeaders(maxMIMEHeaderSize int64) error { r := textproto.NewReader(p.mr.bufReader) - header, err := r.ReadMIMEHeader() + header, err := readMIMEHeader(r, maxMIMEHeaderSize) if err == nil { p.Header = header } + // TODO: Add a distinguishable error to net/textproto. + if err != nil && err.Error() == "message too large" { + err = ErrMessageTooLarge + } return err } @@ -311,6 +315,7 @@ func (p *Part) Close() error { // isn't supported. type Reader struct { bufReader *bufio.Reader + tempDir string // used in tests currentPart *Part partsRead int @@ -321,6 +326,10 @@ type Reader struct { dashBoundary []byte // "--boundary" } +// maxMIMEHeaderSize is the maximum size of a MIME header we will parse, +// including header keys, values, and map overhead. +const maxMIMEHeaderSize = 10 << 20 + // NextPart returns the next part in the multipart or an error. // When there are no more parts, the error io.EOF is returned. // @@ -328,7 +337,7 @@ type Reader struct { // has a value of "quoted-printable", that header is instead // hidden and the body is transparently decoded during Read calls. func (r *Reader) NextPart() (*Part, error) { - return r.nextPart(false) + return r.nextPart(false, maxMIMEHeaderSize) } // NextRawPart returns the next part in the multipart or an error. @@ -337,10 +346,10 @@ func (r *Reader) NextPart() (*Part, error) { // Unlike NextPart, it does not have special handling for // "Content-Transfer-Encoding: quoted-printable". func (r *Reader) NextRawPart() (*Part, error) { - return r.nextPart(true) + return r.nextPart(true, maxMIMEHeaderSize) } -func (r *Reader) nextPart(rawPart bool) (*Part, error) { +func (r *Reader) nextPart(rawPart bool, maxMIMEHeaderSize int64) (*Part, error) { if r.currentPart != nil { r.currentPart.Close() } @@ -365,7 +374,7 @@ func (r *Reader) nextPart(rawPart bool) (*Part, error) { if r.isBoundaryDelimiterLine(line) { r.partsRead++ - bp, err := newPart(r, rawPart) + bp, err := newPart(r, rawPart, maxMIMEHeaderSize) if err != nil { return nil, err } diff --git a/src/mime/multipart/readmimeheader.go b/src/mime/multipart/readmimeheader.go new file mode 100644 index 0000000000000..6836928c9e8b4 --- /dev/null +++ b/src/mime/multipart/readmimeheader.go @@ -0,0 +1,14 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package multipart + +import ( + "net/textproto" + _ "unsafe" // for go:linkname +) + +// readMIMEHeader is defined in package net/textproto. +// +//go:linkname readMIMEHeader net/textproto.readMIMEHeader +func readMIMEHeader(r *textproto.Reader, lim int64) (textproto.MIMEHeader, error) diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 686a8699fb08d..23e49d6b8e354 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -1097,7 +1097,7 @@ func testMissingFile(t *testing.T, req *Request) { t.Errorf("FormFile file = %v, want nil", f) } if fh != nil { - t.Errorf("FormFile file header = %q, want nil", fh) + t.Errorf("FormFile file header = %v, want nil", fh) } if err != ErrMissingFile { t.Errorf("FormFile err = %q, want ErrMissingFile", err) diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go index 4e4999b3c9534..8e800088c1fe7 100644 --- a/src/net/textproto/reader.go +++ b/src/net/textproto/reader.go @@ -7,8 +7,10 @@ package textproto import ( "bufio" "bytes" + "errors" "fmt" "io" + "math" "strconv" "strings" "sync" @@ -477,6 +479,12 @@ var colon = []byte(":") // "Long-Key": {"Even Longer Value"}, // } func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { + return readMIMEHeader(r, math.MaxInt64) +} + +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. +// It is called by the mime/multipart package. +func readMIMEHeader(r *Reader, lim int64) (MIMEHeader, error) { // Avoid lots of small slice allocations later by allocating one // large one ahead of time which we'll cut up into smaller // slices. If this isn't big enough later, we allocate small ones. @@ -526,9 +534,19 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } // Skip initial spaces in value. - value := strings.TrimLeft(string(v), " \t") + value := string(bytes.TrimLeft(v, " \t")) vv := m[key] + if vv == nil { + lim -= int64(len(key)) + lim -= 100 // map entry overhead + } + lim -= int64(len(value)) + if lim < 0 { + // TODO: This should be a distinguishable error (ErrMessageTooLarge) + // to allow mime/multipart to detect it. + return m, errors.New("message too large") + } if vv == nil && len(strs) > 0 { // More than likely this will be a single-element key. // Most headers aren't multi-valued. From 5286ac4ed85a3771cc8a982041fe36dc53d7dc3b Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Wed, 14 Dec 2022 09:43:16 -0800 Subject: [PATCH 14/17] [release-branch.go1.20] crypto/tls: replace all usages of BytesOrPanic Message marshalling makes use of BytesOrPanic a lot, under the assumption that it will never panic. This assumption was incorrect, and specifically crafted handshakes could trigger panics. Rather than just surgically replacing the usages of BytesOrPanic in paths that could panic, replace all usages of it with proper error returns in case there are other ways of triggering panics which we didn't find. In one specific case, the tree routed by expandLabel, we replace the usage of BytesOrPanic, but retain a panic. This function already explicitly panicked elsewhere, and returning an error from it becomes rather painful because it requires changing a large number of APIs. The marshalling is unlikely to ever panic, as the inputs are all either fixed length, or already limited to the sizes required. If it were to panic, it'd likely only be during development. A close inspection shows no paths for a user to cause a panic currently. This patches ends up being rather large, since it requires routing errors back through functions which previously had no error returns. Where possible I've tried to use helpers that reduce the verbosity of frequently repeated stanzas, and to make the diffs as minimal as possible. Thanks to Marten Seemann for reporting this issue. Updates #58001 Fixes #58359 Fixes CVE-2022-41724 Change-Id: Ieb55867ef0a3e1e867b33f09421932510cb58851 Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1679436 Reviewed-by: Julie Qiu TryBot-Result: Security TryBots Run-TryBot: Roland Shoemaker Reviewed-by: Damien Neil (cherry picked from commit 1d4e6ca9454f6cf81d30c5361146fb5988f1b5f6) Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728205 Reviewed-by: Tatiana Bradley Reviewed-on: https://go-review.googlesource.com/c/go/+/468121 Reviewed-by: Than McIntosh Auto-Submit: Michael Pratt TryBot-Bypass: Michael Pratt Run-TryBot: Michael Pratt --- src/crypto/tls/boring_test.go | 2 +- src/crypto/tls/common.go | 2 +- src/crypto/tls/conn.go | 46 +- src/crypto/tls/handshake_client.go | 95 +-- src/crypto/tls/handshake_client_test.go | 4 +- src/crypto/tls/handshake_client_tls13.go | 74 ++- src/crypto/tls/handshake_messages.go | 716 +++++++++++----------- src/crypto/tls/handshake_messages_test.go | 19 +- src/crypto/tls/handshake_server.go | 73 ++- src/crypto/tls/handshake_server_test.go | 31 +- src/crypto/tls/handshake_server_tls13.go | 71 ++- src/crypto/tls/key_schedule.go | 19 +- src/crypto/tls/ticket.go | 8 +- 13 files changed, 657 insertions(+), 503 deletions(-) diff --git a/src/crypto/tls/boring_test.go b/src/crypto/tls/boring_test.go index 59d4d2b2724ea..ba68f355eb037 100644 --- a/src/crypto/tls/boring_test.go +++ b/src/crypto/tls/boring_test.go @@ -269,7 +269,7 @@ func TestBoringClientHello(t *testing.T) { go Client(c, clientConfig).Handshake() srv := Server(s, testConfig) - msg, err := srv.readHandshake() + msg, err := srv.readHandshake(nil) if err != nil { t.Fatal(err) } diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 007f0f47b233c..5394d64ac6c81 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -1394,7 +1394,7 @@ func (c *Certificate) leaf() (*x509.Certificate, error) { } type handshakeMessage interface { - marshal() []byte + marshal() ([]byte, error) unmarshal([]byte) bool } diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 03c72633be1fd..1eefb17206f12 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -1004,18 +1004,37 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) { return n, nil } -// writeRecord writes a TLS record with the given type and payload to the -// connection and updates the record layer state. -func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { +// writeHandshakeRecord writes a handshake message to the connection and updates +// the record layer state. If transcript is non-nil the marshalled message is +// written to it. +func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) { c.out.Lock() defer c.out.Unlock() - return c.writeRecordLocked(typ, data) + data, err := msg.marshal() + if err != nil { + return 0, err + } + if transcript != nil { + transcript.Write(data) + } + + return c.writeRecordLocked(recordTypeHandshake, data) +} + +// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and +// updates the record layer state. +func (c *Conn) writeChangeCipherRecord() error { + c.out.Lock() + defer c.out.Unlock() + _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1}) + return err } // readHandshake reads the next handshake message from -// the record layer. -func (c *Conn) readHandshake() (any, error) { +// the record layer. If transcript is non-nil, the message +// is written to the passed transcriptHash. +func (c *Conn) readHandshake(transcript transcriptHash) (any, error) { for c.hand.Len() < 4 { if err := c.readRecord(); err != nil { return nil, err @@ -1094,6 +1113,11 @@ func (c *Conn) readHandshake() (any, error) { if !m.unmarshal(data) { return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + + if transcript != nil { + transcript.Write(data) + } + return m, nil } @@ -1169,7 +1193,7 @@ func (c *Conn) handleRenegotiation() error { return errors.New("tls: internal error: unexpected renegotiation") } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1215,7 +1239,7 @@ func (c *Conn) handlePostHandshakeMessage() error { return c.handleRenegotiation() } - msg, err := c.readHandshake() + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -1251,7 +1275,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error { defer c.out.Unlock() msg := &keyUpdateMsg{} - _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal()) + msgBytes, err := msg.marshal() + if err != nil { + return err + } + _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes) if err != nil { // Surface the error at the next write. c.out.setErrorLocked(err) diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 7cf906c91d8ee..63d86b9f3a7ef 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -162,7 +162,10 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } c.serverName = hello.serverName - cacheKey, session, earlySecret, binderKey := c.loadSession(hello) + cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello) + if err != nil { + return err + } if cacheKey != "" && session != nil { defer func() { // If we got a handshake failure when resuming a session, throw away @@ -177,11 +180,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { }() } - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(hello, nil); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -246,9 +250,9 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { } func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, - session *ClientSessionState, earlySecret, binderKey []byte) { + session *ClientSessionState, earlySecret, binderKey []byte, err error) { if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil { - return "", nil, nil, nil + return "", nil, nil, nil, nil } hello.ticketSupported = true @@ -263,14 +267,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // renegotiation is primarily used to allow a client to send a client // certificate, which would be skipped if session resumption occurred. if c.handshakes != 0 { - return "", nil, nil, nil + return "", nil, nil, nil, nil } // Try to resume a previously negotiated TLS session, if available. cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) session, ok := c.config.ClientSessionCache.Get(cacheKey) if !ok || session == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Check that version used for the previous session is still valid. @@ -282,7 +286,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !versOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Check that the cached server certificate is not expired, and that it's @@ -291,16 +295,16 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, if !c.config.InsecureSkipVerify { if len(session.verifiedChains) == 0 { // The original connection had InsecureSkipVerify, while this doesn't. - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } serverCert := session.serverCertificates[0] if c.config.time().After(serverCert.NotAfter) { // Expired certificate, delete the entry. c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } if err := serverCert.VerifyHostname(c.config.ServerName); err != nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } } @@ -308,7 +312,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // In TLS 1.2 the cipher suite must match the resumed session. Ensure we // are still offering it. if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } hello.sessionTicket = session.sessionTicket @@ -318,14 +322,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, // Check that the session ticket is not expired. if c.config.time().After(session.useBy) { c.config.ClientSessionCache.Put(cacheKey, nil) - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // In TLS 1.3 the KDF hash must match the resumed session. Ensure we // offer at least one cipher suite with that hash. cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite) if cipherSuite == nil { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } cipherSuiteOk := false for _, offeredID := range hello.cipherSuites { @@ -336,7 +340,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, } } if !cipherSuiteOk { - return cacheKey, nil, nil, nil + return cacheKey, nil, nil, nil, nil } // Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1. @@ -354,9 +358,15 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string, earlySecret = cipherSuite.extract(psk, nil) binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil) transcript := cipherSuite.hash.New() - transcript.Write(hello.marshalWithoutBinders()) + helloBytes, err := hello.marshalWithoutBinders() + if err != nil { + return "", nil, nil, nil, err + } + transcript.Write(helloBytes) pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)} - hello.updateBinders(pskBinders) + if err := hello.updateBinders(pskBinders); err != nil { + return "", nil, nil, nil, err + } return } @@ -401,8 +411,12 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.hello.marshal()) - hs.finishedHash.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil { + return err + } + if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil { + return err + } c.buffering = true c.didResume = isResume @@ -473,7 +487,7 @@ func (hs *clientHandshakeState) pickCipherSuite() error { func (hs *clientHandshakeState) doFullHandshake() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -482,9 +496,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -502,11 +515,10 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return errors.New("tls: received unexpected CertificateStatus message") } - hs.finishedHash.Write(cs.marshal()) c.ocspResponse = cs.response - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -535,14 +547,13 @@ func (hs *clientHandshakeState) doFullHandshake() error { skx, ok := msg.(*serverKeyExchangeMsg) if ok { - hs.finishedHash.Write(skx.marshal()) err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) if err != nil { c.sendAlert(alertUnexpectedMessage) return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -553,7 +564,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { certReq, ok := msg.(*certificateRequestMsg) if ok { certRequested = true - hs.finishedHash.Write(certReq.marshal()) cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq) if chainToSend, err = c.getClientCertificate(cri); err != nil { @@ -561,7 +571,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -572,7 +582,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(shd, msg) } - hs.finishedHash.Write(shd.marshal()) // If the server requested a certificate then we have to send a // Certificate message, even if it's empty because we don't have a @@ -580,8 +589,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { if certRequested { certMsg = new(certificateMsg) certMsg.certificates = chainToSend.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } } @@ -592,8 +600,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } if ckx != nil { - hs.finishedHash.Write(ckx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil { return err } } @@ -640,8 +647,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } - hs.finishedHash.Write(certVerify.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil { return err } } @@ -776,7 +782,10 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -792,7 +801,11 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { c.sendAlert(alertHandshakeFailure) return errors.New("tls: server's Finished message was incorrect") } - hs.finishedHash.Write(serverFinished.marshal()) + + if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -803,7 +816,7 @@ func (hs *clientHandshakeState) readSessionTicket() error { } c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -812,7 +825,6 @@ func (hs *clientHandshakeState) readSessionTicket() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(sessionTicketMsg, msg) } - hs.finishedHash.Write(sessionTicketMsg.marshal()) hs.session = &ClientSessionState{ sessionTicket: sessionTicketMsg.ticket, @@ -832,14 +844,13 @@ func (hs *clientHandshakeState) readSessionTicket() error { func (hs *clientHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } copy(out, finished.verifyData) diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 380de9f6fb50e..749c9fc95452a 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -1257,7 +1257,7 @@ func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) { cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256, alpnProtocol: "how-about-this", } - serverHelloBytes := serverHello.marshal() + serverHelloBytes := mustMarshal(t, serverHello) s.Write([]byte{ byte(recordTypeHandshake), @@ -1500,7 +1500,7 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { random: make([]byte, 32), cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384, } - serverHelloBytes := serverHello.marshal() + serverHelloBytes := mustMarshal(t, serverHello) s.Write([]byte{ byte(recordTypeHandshake), diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 3bdd9373d668e..fefba01a0611a 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -62,7 +62,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } hs.transcript = hs.suite.hash.New() - hs.transcript.Write(hs.hello.marshal()) + + if err := transcriptMsg(hs.hello, hs.transcript); err != nil { + return err + } if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) { if err := hs.sendDummyChangeCipherSpec(); err != nil { @@ -73,7 +76,9 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } } - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } c.buffering = true if err := hs.processServerHello(); err != nil { @@ -172,8 +177,7 @@ func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } // processHelloRetryRequest handles the HRR in hs.serverHello, modifies and @@ -188,7 +192,9 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) hs.transcript.Write(chHash) - hs.transcript.Write(hs.serverHello.marshal()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } // The only HelloRetryRequest extensions we support are key_share and // cookie, and clients must abort the handshake if the HRR would not result @@ -253,10 +259,18 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { transcript := hs.suite.hash.New() transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) transcript.Write(chHash) - transcript.Write(hs.serverHello.marshal()) - transcript.Write(hs.hello.marshalWithoutBinders()) + if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil { + return err + } + helloBytes, err := hs.hello.marshalWithoutBinders() + if err != nil { + return err + } + transcript.Write(helloBytes) pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)} - hs.hello.updateBinders(pskBinders) + if err := hs.hello.updateBinders(pskBinders); err != nil { + return err + } } else { // Server selected a cipher suite incompatible with the PSK. hs.hello.pskIdentities = nil @@ -264,12 +278,12 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { } } - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } - msg, err := c.readHandshake() + // serverHelloMsg is not included in the transcript + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -363,6 +377,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { if !hs.usingPSK { earlySecret = hs.suite.extract(nil, nil) } + handshakeSecret := hs.suite.extract(sharedKey, hs.suite.deriveSecret(earlySecret, "derived", nil)) @@ -393,7 +408,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { func (hs *clientHandshakeStateTLS13) readServerParameters() error { c := hs.c - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -403,7 +418,6 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(encryptedExtensions, msg) } - hs.transcript.Write(encryptedExtensions.marshal()) if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil { c.sendAlert(alertUnsupportedExtension) @@ -432,18 +446,16 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return nil } - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } certReq, ok := msg.(*certificateRequestMsgTLS13) if ok { - hs.transcript.Write(certReq.marshal()) - hs.certReq = certReq - msg, err = c.readHandshake() + msg, err = c.readHandshake(hs.transcript) if err != nil { return err } @@ -458,7 +470,6 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { c.sendAlert(alertDecodeError) return errors.New("tls: received empty certificates message") } - hs.transcript.Write(certMsg.marshal()) c.scts = certMsg.certificate.SignedCertificateTimestamps c.ocspResponse = certMsg.certificate.OCSPStaple @@ -467,7 +478,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return err } - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -498,7 +512,9 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { return errors.New("tls: invalid signature by the server certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } return nil } @@ -506,7 +522,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { func (hs *clientHandshakeStateTLS13) readServerFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -523,7 +542,9 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error { return errors.New("tls: invalid server finished hash") } - hs.transcript.Write(finished.marshal()) + if err := transcriptMsg(finished, hs.transcript); err != nil { + return err + } // Derive secrets that take context through the server Finished. @@ -572,8 +593,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -610,8 +630,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -625,8 +644,7 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index 7ab0f100b8bce..695aacf126a15 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -5,6 +5,7 @@ package tls import ( + "errors" "fmt" "strings" @@ -94,9 +95,181 @@ type clientHelloMsg struct { pskBinders [][]byte } -func (m *clientHelloMsg) marshal() []byte { +func (m *clientHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if len(m.serverName) > 0 { + // RFC 6066, Section 3 + exts.AddUint16(extensionServerName) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(0) // name_type = host_name + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.serverName)) + }) + }) + }) + } + if m.ocspStapling { + // RFC 4366, Section 3.6 + exts.AddUint16(extensionStatusRequest) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8(1) // status_type = ocsp + exts.AddUint16(0) // empty responder_id_list + exts.AddUint16(0) // empty request_extensions + }) + } + if len(m.supportedCurves) > 0 { + // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 + exts.AddUint16(extensionSupportedCurves) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, curve := range m.supportedCurves { + exts.AddUint16(uint16(curve)) + } + }) + }) + } + if len(m.supportedPoints) > 0 { + // RFC 4492, Section 5.1.2 + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + if m.ticketSupported { + // RFC 5077, Section 3.2 + exts.AddUint16(extensionSessionTicket) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.sessionTicket) + }) + } + if len(m.supportedSignatureAlgorithms) > 0 { + // RFC 5246, Section 7.4.1.4.1 + exts.AddUint16(extensionSignatureAlgorithms) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + // RFC 8446, Section 4.2.3 + exts.AddUint16(extensionSignatureAlgorithmsCert) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + exts.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if m.secureRenegotiationSupported { + // RFC 5746, Section 3.2 + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocols) > 0 { + // RFC 7301, Section 3.1 + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, proto := range m.alpnProtocols { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(proto)) + }) + } + }) + }) + } + if m.scts { + // RFC 6962, Section 3.3.1 + exts.AddUint16(extensionSCT) + exts.AddUint16(0) // empty extension_data + } + if len(m.supportedVersions) > 0 { + // RFC 8446, Section 4.2.1 + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, vers := range m.supportedVersions { + exts.AddUint16(vers) + } + }) + }) + } + if len(m.cookie) > 0 { + // RFC 8446, Section 4.2.2 + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if len(m.keyShares) > 0 { + // RFC 8446, Section 4.2.8 + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, ks := range m.keyShares { + exts.AddUint16(uint16(ks.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(ks.data) + }) + } + }) + }) + } + if m.earlyData { + // RFC 8446, Section 4.2.10 + exts.AddUint16(extensionEarlyData) + exts.AddUint16(0) // empty extension_data + } + if len(m.pskModes) > 0 { + // RFC 8446, Section 4.2.9 + exts.AddUint16(extensionPSKModes) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.pskModes) + }) + }) + } + if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension + // RFC 8446, Section 4.2.11 + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, psk := range m.pskIdentities { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(psk.label) + }) + exts.AddUint32(psk.obfuscatedTicketAge) + } + }) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, binder := range m.pskBinders { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(binder) + }) + } + }) + }) + } + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -116,219 +289,53 @@ func (m *clientHelloMsg) marshal() []byte { b.AddBytes(m.compressionMethods) }) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if len(m.serverName) > 0 { - // RFC 6066, Section 3 - b.AddUint16(extensionServerName) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(0) // name_type = host_name - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.serverName)) - }) - }) - }) - } - if m.ocspStapling { - // RFC 4366, Section 3.6 - b.AddUint16(extensionStatusRequest) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8(1) // status_type = ocsp - b.AddUint16(0) // empty responder_id_list - b.AddUint16(0) // empty request_extensions - }) - } - if len(m.supportedCurves) > 0 { - // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7 - b.AddUint16(extensionSupportedCurves) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, curve := range m.supportedCurves { - b.AddUint16(uint16(curve)) - } - }) - }) - } - if len(m.supportedPoints) > 0 { - // RFC 4492, Section 5.1.2 - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - if m.ticketSupported { - // RFC 5077, Section 3.2 - b.AddUint16(extensionSessionTicket) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.sessionTicket) - }) - } - if len(m.supportedSignatureAlgorithms) > 0 { - // RFC 5246, Section 7.4.1.4.1 - b.AddUint16(extensionSignatureAlgorithms) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithms { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if len(m.supportedSignatureAlgorithmsCert) > 0 { - // RFC 8446, Section 4.2.3 - b.AddUint16(extensionSignatureAlgorithmsCert) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { - b.AddUint16(uint16(sigAlgo)) - } - }) - }) - } - if m.secureRenegotiationSupported { - // RFC 5746, Section 3.2 - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocols) > 0 { - // RFC 7301, Section 3.1 - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, proto := range m.alpnProtocols { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(proto)) - }) - } - }) - }) - } - if m.scts { - // RFC 6962, Section 3.3.1 - b.AddUint16(extensionSCT) - b.AddUint16(0) // empty extension_data - } - if len(m.supportedVersions) > 0 { - // RFC 8446, Section 4.2.1 - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - for _, vers := range m.supportedVersions { - b.AddUint16(vers) - } - }) - }) - } - if len(m.cookie) > 0 { - // RFC 8446, Section 4.2.2 - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if len(m.keyShares) > 0 { - // RFC 8446, Section 4.2.8 - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, ks := range m.keyShares { - b.AddUint16(uint16(ks.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(ks.data) - }) - } - }) - }) - } - if m.earlyData { - // RFC 8446, Section 4.2.10 - b.AddUint16(extensionEarlyData) - b.AddUint16(0) // empty extension_data - } - if len(m.pskModes) > 0 { - // RFC 8446, Section 4.2.9 - b.AddUint16(extensionPSKModes) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.pskModes) - }) - }) - } - if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension - // RFC 8446, Section 4.2.11 - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, psk := range m.pskIdentities { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(psk.label) - }) - b.AddUint32(psk.obfuscatedTicketAge) - } - }) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, binder := range m.pskBinders { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(binder) - }) - } - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } // marshalWithoutBinders returns the ClientHello through the // PreSharedKeyExtension.identities field, according to RFC 8446, Section // 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length. -func (m *clientHelloMsg) marshalWithoutBinders() []byte { +func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) { bindersLen := 2 // uint16 length prefix for _, binder := range m.pskBinders { bindersLen += 1 // uint8 length prefix bindersLen += len(binder) } - fullMessage := m.marshal() - return fullMessage[:len(fullMessage)-bindersLen] + fullMessage, err := m.marshal() + if err != nil { + return nil, err + } + return fullMessage[:len(fullMessage)-bindersLen], nil } // updateBinders updates the m.pskBinders field, if necessary updating the // cached marshaled representation. The supplied binders must have the same // length as the current m.pskBinders. -func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { +func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error { if len(pskBinders) != len(m.pskBinders) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } for i := range m.pskBinders { if len(pskBinders[i]) != len(m.pskBinders[i]) { - panic("tls: internal error: pskBinders length mismatch") + return errors.New("tls: internal error: pskBinders length mismatch") } } m.pskBinders = pskBinders if m.raw != nil { - lenWithoutBinders := len(m.marshalWithoutBinders()) + helloBytes, err := m.marshalWithoutBinders() + if err != nil { + return err + } + lenWithoutBinders := len(helloBytes) b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders]) b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { for _, binder := range m.pskBinders { @@ -338,9 +345,11 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) { } }) if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) { - panic("tls: internal error: failed to update binders") + return errors.New("tls: internal error: failed to update binders") } } + + return nil } func (m *clientHelloMsg) unmarshal(data []byte) bool { @@ -618,9 +627,98 @@ type serverHelloMsg struct { selectedGroup CurveID } -func (m *serverHelloMsg) marshal() []byte { +func (m *serverHelloMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil + } + + var exts cryptobyte.Builder + if m.ocspStapling { + exts.AddUint16(extensionStatusRequest) + exts.AddUint16(0) // empty extension_data + } + if m.ticketSupported { + exts.AddUint16(extensionSessionTicket) + exts.AddUint16(0) // empty extension_data + } + if m.secureRenegotiationSupported { + exts.AddUint16(extensionRenegotiationInfo) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.secureRenegotiation) + }) + }) + } + if len(m.alpnProtocol) > 0 { + exts.AddUint16(extensionALPN) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + if len(m.scts) > 0 { + exts.AddUint16(extensionSCT) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + for _, sct := range m.scts { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(sct) + }) + } + }) + }) + } + if m.supportedVersion != 0 { + exts.AddUint16(extensionSupportedVersions) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.supportedVersion) + }) + } + if m.serverShare.group != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.serverShare.group)) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.serverShare.data) + }) + }) + } + if m.selectedIdentityPresent { + exts.AddUint16(extensionPreSharedKey) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(m.selectedIdentity) + }) + } + + if len(m.cookie) > 0 { + exts.AddUint16(extensionCookie) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.cookie) + }) + }) + } + if m.selectedGroup != 0 { + exts.AddUint16(extensionKeyShare) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint16(uint16(m.selectedGroup)) + }) + } + if len(m.supportedPoints) > 0 { + exts.AddUint16(extensionSupportedPoints) + exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.supportedPoints) + }) + }) + } + + extBytes, err := exts.Bytes() + if err != nil { + return nil, err } var b cryptobyte.Builder @@ -634,104 +732,15 @@ func (m *serverHelloMsg) marshal() []byte { b.AddUint16(m.cipherSuite) b.AddUint8(m.compressionMethod) - // If extensions aren't present, omit them. - var extensionsPresent bool - bWithoutExtensions := *b - - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - if m.ocspStapling { - b.AddUint16(extensionStatusRequest) - b.AddUint16(0) // empty extension_data - } - if m.ticketSupported { - b.AddUint16(extensionSessionTicket) - b.AddUint16(0) // empty extension_data - } - if m.secureRenegotiationSupported { - b.AddUint16(extensionRenegotiationInfo) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.secureRenegotiation) - }) - }) - } - if len(m.alpnProtocol) > 0 { - b.AddUint16(extensionALPN) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes([]byte(m.alpnProtocol)) - }) - }) - }) - } - if len(m.scts) > 0 { - b.AddUint16(extensionSCT) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - for _, sct := range m.scts { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(sct) - }) - } - }) - }) - } - if m.supportedVersion != 0 { - b.AddUint16(extensionSupportedVersions) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.supportedVersion) - }) - } - if m.serverShare.group != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.serverShare.group)) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.serverShare.data) - }) - }) - } - if m.selectedIdentityPresent { - b.AddUint16(extensionPreSharedKey) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(m.selectedIdentity) - }) - } - - if len(m.cookie) > 0 { - b.AddUint16(extensionCookie) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.cookie) - }) - }) - } - if m.selectedGroup != 0 { - b.AddUint16(extensionKeyShare) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint16(uint16(m.selectedGroup)) - }) - } - if len(m.supportedPoints) > 0 { - b.AddUint16(extensionSupportedPoints) - b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { - b.AddBytes(m.supportedPoints) - }) - }) - } - - extensionsPresent = len(b.BytesOrPanic()) > 2 - }) - - if !extensionsPresent { - *b = bWithoutExtensions + if len(extBytes) > 0 { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(extBytes) + }) } }) - m.raw = b.BytesOrPanic() - return m.raw + m.raw, err = b.Bytes() + return m.raw, err } func (m *serverHelloMsg) unmarshal(data []byte) bool { @@ -855,9 +864,9 @@ type encryptedExtensionsMsg struct { alpnProtocol string } -func (m *encryptedExtensionsMsg) marshal() []byte { +func (m *encryptedExtensionsMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -877,8 +886,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { @@ -926,10 +936,10 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { type endOfEarlyDataMsg struct{} -func (m *endOfEarlyDataMsg) marshal() []byte { +func (m *endOfEarlyDataMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeEndOfEarlyData - return x + return x, nil } func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { @@ -941,9 +951,9 @@ type keyUpdateMsg struct { updateRequested bool } -func (m *keyUpdateMsg) marshal() []byte { +func (m *keyUpdateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -956,8 +966,9 @@ func (m *keyUpdateMsg) marshal() []byte { } }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *keyUpdateMsg) unmarshal(data []byte) bool { @@ -989,9 +1000,9 @@ type newSessionTicketMsgTLS13 struct { maxEarlyData uint32 } -func (m *newSessionTicketMsgTLS13) marshal() []byte { +func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1016,8 +1027,9 @@ func (m *newSessionTicketMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { @@ -1070,9 +1082,9 @@ type certificateRequestMsgTLS13 struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsgTLS13) marshal() []byte { +func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1131,8 +1143,9 @@ func (m *certificateRequestMsgTLS13) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { @@ -1216,9 +1229,9 @@ type certificateMsg struct { certificates [][]byte } -func (m *certificateMsg) marshal() (x []byte) { +func (m *certificateMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var i int @@ -1227,7 +1240,7 @@ func (m *certificateMsg) marshal() (x []byte) { } length := 3 + 3*len(m.certificates) + i - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificate x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1248,7 +1261,7 @@ func (m *certificateMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateMsg) unmarshal(data []byte) bool { @@ -1295,9 +1308,9 @@ type certificateMsgTLS13 struct { scts bool } -func (m *certificateMsgTLS13) marshal() []byte { +func (m *certificateMsgTLS13) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1315,8 +1328,9 @@ func (m *certificateMsgTLS13) marshal() []byte { marshalCertificate(b, certificate) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) { @@ -1439,9 +1453,9 @@ type serverKeyExchangeMsg struct { key []byte } -func (m *serverKeyExchangeMsg) marshal() []byte { +func (m *serverKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.key) x := make([]byte, length+4) @@ -1452,7 +1466,7 @@ func (m *serverKeyExchangeMsg) marshal() []byte { copy(x[4:], m.key) m.raw = x - return x + return x, nil } func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1469,9 +1483,9 @@ type certificateStatusMsg struct { response []byte } -func (m *certificateStatusMsg) marshal() []byte { +func (m *certificateStatusMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1483,8 +1497,9 @@ func (m *certificateStatusMsg) marshal() []byte { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateStatusMsg) unmarshal(data []byte) bool { @@ -1503,10 +1518,10 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool { type serverHelloDoneMsg struct{} -func (m *serverHelloDoneMsg) marshal() []byte { +func (m *serverHelloDoneMsg) marshal() ([]byte, error) { x := make([]byte, 4) x[0] = typeServerHelloDone - return x + return x, nil } func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { @@ -1518,9 +1533,9 @@ type clientKeyExchangeMsg struct { ciphertext []byte } -func (m *clientKeyExchangeMsg) marshal() []byte { +func (m *clientKeyExchangeMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } length := len(m.ciphertext) x := make([]byte, length+4) @@ -1531,7 +1546,7 @@ func (m *clientKeyExchangeMsg) marshal() []byte { copy(x[4:], m.ciphertext) m.raw = x - return x + return x, nil } func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { @@ -1552,9 +1567,9 @@ type finishedMsg struct { verifyData []byte } -func (m *finishedMsg) marshal() []byte { +func (m *finishedMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1563,8 +1578,9 @@ func (m *finishedMsg) marshal() []byte { b.AddBytes(m.verifyData) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *finishedMsg) unmarshal(data []byte) bool { @@ -1586,9 +1602,9 @@ type certificateRequestMsg struct { certificateAuthorities [][]byte } -func (m *certificateRequestMsg) marshal() (x []byte) { +func (m *certificateRequestMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 4346, Section 7.4.4. @@ -1603,7 +1619,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { length += 2 + 2*len(m.supportedSignatureAlgorithms) } - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeCertificateRequest x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1638,7 +1654,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) { } m.raw = x - return + return m.raw, nil } func (m *certificateRequestMsg) unmarshal(data []byte) bool { @@ -1724,9 +1740,9 @@ type certificateVerifyMsg struct { signature []byte } -func (m *certificateVerifyMsg) marshal() (x []byte) { +func (m *certificateVerifyMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } var b cryptobyte.Builder @@ -1740,8 +1756,9 @@ func (m *certificateVerifyMsg) marshal() (x []byte) { }) }) - m.raw = b.BytesOrPanic() - return m.raw + var err error + m.raw, err = b.Bytes() + return m.raw, err } func (m *certificateVerifyMsg) unmarshal(data []byte) bool { @@ -1764,15 +1781,15 @@ type newSessionTicketMsg struct { ticket []byte } -func (m *newSessionTicketMsg) marshal() (x []byte) { +func (m *newSessionTicketMsg) marshal() ([]byte, error) { if m.raw != nil { - return m.raw + return m.raw, nil } // See RFC 5077, Section 3.3. ticketLen := len(m.ticket) length := 2 + 4 + ticketLen - x = make([]byte, 4+length) + x := make([]byte, 4+length) x[0] = typeNewSessionTicket x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) @@ -1783,7 +1800,7 @@ func (m *newSessionTicketMsg) marshal() (x []byte) { m.raw = x - return + return m.raw, nil } func (m *newSessionTicketMsg) unmarshal(data []byte) bool { @@ -1811,10 +1828,25 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool { type helloRequestMsg struct { } -func (*helloRequestMsg) marshal() []byte { - return []byte{typeHelloRequest, 0, 0, 0} +func (*helloRequestMsg) marshal() ([]byte, error) { + return []byte{typeHelloRequest, 0, 0, 0}, nil } func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } + +type transcriptHash interface { + Write([]byte) (int, error) +} + +// transcriptMsg is a helper used to marshal and hash messages which typically +// are not written to the wire, and as such aren't hashed during Conn.writeRecord. +func transcriptMsg(msg handshakeMessage, h transcriptHash) error { + data, err := msg.marshal() + if err != nil { + return err + } + h.Write(data) + return nil +} diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go index c6fc8f2bf3783..206e2fb024feb 100644 --- a/src/crypto/tls/handshake_messages_test.go +++ b/src/crypto/tls/handshake_messages_test.go @@ -38,6 +38,15 @@ var tests = []any{ &certificateMsgTLS13{}, } +func mustMarshal(t *testing.T, msg handshakeMessage) []byte { + t.Helper() + b, err := msg.marshal() + if err != nil { + t.Fatal(err) + } + return b +} + func TestMarshalUnmarshal(t *testing.T) { rand := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -56,7 +65,7 @@ func TestMarshalUnmarshal(t *testing.T) { } m1 := v.Interface().(handshakeMessage) - marshaled := m1.marshal() + marshaled := mustMarshal(t, m1) m2 := iface.(handshakeMessage) if !m2.unmarshal(marshaled) { t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) @@ -409,12 +418,12 @@ func TestRejectEmptySCTList(t *testing.T) { var random [32]byte sct := []byte{0x42, 0x42, 0x42, 0x42} - serverHello := serverHelloMsg{ + serverHello := &serverHelloMsg{ vers: VersionTLS12, random: random[:], scts: [][]byte{sct}, } - serverHelloBytes := serverHello.marshal() + serverHelloBytes := mustMarshal(t, serverHello) var serverHelloCopy serverHelloMsg if !serverHelloCopy.unmarshal(serverHelloBytes) { @@ -452,12 +461,12 @@ func TestRejectEmptySCT(t *testing.T) { // not be zero length. var random [32]byte - serverHello := serverHelloMsg{ + serverHello := &serverHelloMsg{ vers: VersionTLS12, random: random[:], scts: [][]byte{nil}, } - serverHelloBytes := serverHello.marshal() + serverHelloBytes := mustMarshal(t, serverHello) var serverHelloCopy serverHelloMsg if serverHelloCopy.unmarshal(serverHelloBytes) { diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 682cfc20619f6..e22f284cfb428 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -128,7 +128,9 @@ func (hs *serverHandshakeState) handshake() error { // readClientHello reads a ClientHello message and selects the protocol version. func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) { - msg, err := c.readHandshake() + // clientHelloMsg is included in the transcript, but we haven't initialized + // it yet. The respective handshake functions will record it themselves. + msg, err := c.readHandshake(nil) if err != nil { return nil, err } @@ -462,9 +464,10 @@ func (hs *serverHandshakeState) doResumeHandshake() error { hs.hello.ticketSupported = hs.sessionState.usedOldKey hs.finishedHash = newFinishedHash(c.vers, hs.suite) hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } @@ -502,24 +505,23 @@ func (hs *serverHandshakeState) doFullHandshake() error { // certificates won't be used. hs.finishedHash.discardHandshakeBuffer() } - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil { return err } certMsg := new(certificateMsg) certMsg.certificates = hs.cert.Certificate - hs.finishedHash.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil { return err } if hs.hello.ocspStapling { certStatus := new(certificateStatusMsg) certStatus.response = hs.cert.OCSPStaple - hs.finishedHash.Write(certStatus.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil { return err } } @@ -531,8 +533,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { return err } if skx != nil { - hs.finishedHash.Write(skx.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil { return err } } @@ -558,15 +559,13 @@ func (hs *serverHandshakeState) doFullHandshake() error { if c.config.ClientCAs != nil { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.finishedHash.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil { return err } } helloDone := new(serverHelloDoneMsg) - hs.finishedHash.Write(helloDone.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil { return err } @@ -576,7 +575,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { var pub crypto.PublicKey // public key for client auth, if any - msg, err := c.readHandshake() + msg, err := c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -589,7 +588,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.finishedHash.Write(certMsg.marshal()) if err := c.processCertsFromClient(Certificate{ Certificate: certMsg.certificates, @@ -600,7 +598,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { pub = c.peerCertificates[0].PublicKey } - msg, err = c.readHandshake() + msg, err = c.readHandshake(&hs.finishedHash) if err != nil { return err } @@ -618,7 +616,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(ckx, msg) } - hs.finishedHash.Write(ckx.marshal()) preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) if err != nil { @@ -638,7 +635,10 @@ func (hs *serverHandshakeState) doFullHandshake() error { // to the client's certificate. This allows us to verify that the client is in // possession of the private key of the certificate. if len(c.peerCertificates) > 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -673,7 +673,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.finishedHash.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil { + return err + } } hs.finishedHash.discardHandshakeBuffer() @@ -713,7 +715,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return err } - msg, err := c.readHandshake() + // finishedMsg is included in the transcript, but not until after we + // check the client version, since the state before this message was + // sent is used during verification. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -730,7 +735,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash.Write(clientFinished.marshal()) + if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil { + return err + } + copy(out, verify) return nil } @@ -764,14 +772,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { masterSecret: hs.masterSecret, certificates: certsFromClient, } - var err error - m.ticket, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + return err + } + m.ticket, err = c.encryptTicket(stateBytes) if err != nil { return err } - hs.finishedHash.Write(m.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil { return err } @@ -781,14 +791,13 @@ func (hs *serverHandshakeState) sendSessionTicket() error { func (hs *serverHandshakeState) sendFinished(out []byte) error { c := hs.c - if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + if err := c.writeChangeCipherRecord(); err != nil { return err } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) - hs.finishedHash.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil { return err } diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index 78889f4ad2746..04abdcca89040 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -30,6 +30,13 @@ func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) { testClientHelloFailure(t, serverConfig, m, "") } +// testFatal is a hack to prevent the compiler from complaining that there is a +// call to t.Fatal from a non-test goroutine +func testFatal(t *testing.T, err error) { + t.Helper() + t.Fatal(err) +} + func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { c, s := localPipe(t) go func() { @@ -37,7 +44,9 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa if ch, ok := m.(*clientHelloMsg); ok { cli.vers = ch.vers } - cli.writeRecord(recordTypeHandshake, m.marshal()) + if _, err := cli.writeHandshakeRecord(m, nil); err != nil { + testFatal(t, err) + } c.Close() }() ctx := context.Background() @@ -194,7 +203,9 @@ func TestRenegotiationExtension(t *testing.T) { go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { + testFatal(t, err) + } buf := make([]byte, 1024) n, err := c.Read(buf) @@ -253,8 +264,10 @@ func TestTLS12OnlyCipherSuites(t *testing.T) { go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) - reply, err := cli.readHandshake() + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { + testFatal(t, err) + } + reply, err := cli.readHandshake(nil) c.Close() if err != nil { replyChan <- err @@ -311,8 +324,10 @@ func TestTLSPointFormats(t *testing.T) { go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) - reply, err := cli.readHandshake() + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { + testFatal(t, err) + } + reply, err := cli.readHandshake(nil) c.Close() if err != nil { replyChan <- err @@ -1426,7 +1441,9 @@ func TestSNIGivenOnFailure(t *testing.T) { go func() { cli := Client(c, testConfig) cli.vers = clientHello.vers - cli.writeRecord(recordTypeHandshake, clientHello.marshal()) + if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil { + testFatal(t, err) + } c.Close() }() conn := Server(s, serverConfig) diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 80d4dce3c5d6e..b7b568cd84ac8 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -306,7 +306,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { c.sendAlert(alertInternalError) return errors.New("tls: internal error: failed to clone hash") } - transcript.Write(hs.clientHello.marshalWithoutBinders()) + clientHelloBytes, err := hs.clientHello.marshalWithoutBinders() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + transcript.Write(clientHelloBytes) pskBinder := hs.suite.finishedHash(binderKey, transcript) if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) { c.sendAlert(alertDecryptError) @@ -397,8 +402,7 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { } hs.sentDummyCCS = true - _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) - return err + return hs.c.writeChangeCipherRecord() } func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { @@ -406,7 +410,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. - hs.transcript.Write(hs.clientHello.marshal()) + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } chHash := hs.transcript.Sum(nil) hs.transcript.Reset() hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))}) @@ -422,8 +428,7 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) selectedGroup: selectedGroup, } - hs.transcript.Write(helloRetryRequest.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil { return err } @@ -431,7 +436,8 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) return err } - msg, err := c.readHandshake() + // clientHelloMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } @@ -522,9 +528,10 @@ func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool { func (hs *serverHandshakeStateTLS13) sendServerParameters() error { c := hs.c - hs.transcript.Write(hs.clientHello.marshal()) - hs.transcript.Write(hs.hello.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil { + return err + } + if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil { return err } @@ -567,8 +574,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { encryptedExtensions.alpnProtocol = selectedProto c.clientProtocol = selectedProto - hs.transcript.Write(encryptedExtensions.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil { return err } @@ -597,8 +603,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } - hs.transcript.Write(certReq.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil { return err } } @@ -609,8 +614,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0 certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 - hs.transcript.Write(certMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil { return err } @@ -641,8 +645,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error { } certVerifyMsg.signature = sig - hs.transcript.Write(certVerifyMsg.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil { return err } @@ -656,8 +659,7 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error { verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript), } - hs.transcript.Write(finished.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil { return err } @@ -718,7 +720,9 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { finishedMsg := &finishedMsg{ verifyData: hs.clientFinished, } - hs.transcript.Write(finishedMsg.marshal()) + if err := transcriptMsg(finishedMsg, hs.transcript); err != nil { + return err + } if !hs.shouldSendSessionTickets() { return nil @@ -743,8 +747,12 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { SignedCertificateTimestamps: c.scts, }, } - var err error - m.label, err = c.encryptTicket(state.marshal()) + stateBytes, err := state.marshal() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + m.label, err = c.encryptTicket(stateBytes) if err != nil { return err } @@ -763,7 +771,7 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error { // ticket_nonce, which must be unique per connection, is always left at // zero because we only ever send one ticket per connection. - if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + if _, err := c.writeHandshakeRecord(m, nil); err != nil { return err } @@ -788,7 +796,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { // If we requested a client certificate, then the client must send a // certificate message. If it's empty, no CertificateVerify is sent. - msg, err := c.readHandshake() + msg, err := c.readHandshake(hs.transcript) if err != nil { return err } @@ -798,7 +806,6 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) } - hs.transcript.Write(certMsg.marshal()) if err := c.processCertsFromClient(certMsg.certificate); err != nil { return err @@ -812,7 +819,10 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { } if len(certMsg.certificate.Certificate) != 0 { - msg, err = c.readHandshake() + // certificateVerifyMsg is included in the transcript, but not until + // after we verify the handshake signature, since the state before + // this message was sent is used. + msg, err = c.readHandshake(nil) if err != nil { return err } @@ -843,7 +853,9 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { return errors.New("tls: invalid signature by the client certificate: " + err.Error()) } - hs.transcript.Write(certVerify.marshal()) + if err := transcriptMsg(certVerify, hs.transcript); err != nil { + return err + } } // If we waited until the client certificates to send session tickets, we @@ -858,7 +870,8 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { func (hs *serverHandshakeStateTLS13) readClientFinished() error { c := hs.c - msg, err := c.readHandshake() + // finishedMsg is not included in the transcript. + msg, err := c.readHandshake(nil) if err != nil { return err } diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go index 8150d804a4043..ae8f80a7cfcc5 100644 --- a/src/crypto/tls/key_schedule.go +++ b/src/crypto/tls/key_schedule.go @@ -8,6 +8,7 @@ import ( "crypto/ecdh" "crypto/hmac" "errors" + "fmt" "hash" "io" @@ -40,8 +41,24 @@ func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []by hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { b.AddBytes(context) }) + hkdfLabelBytes, err := hkdfLabel.Bytes() + if err != nil { + // Rather than calling BytesOrPanic, we explicitly handle this error, in + // order to provide a reasonable error message. It should be basically + // impossible for this to panic, and routing errors back through the + // tree rooted in this function is quite painful. The labels are fixed + // size, and the context is either a fixed-length computed hash, or + // parsed from a field which has the same length limitation. As such, an + // error here is likely to only be caused during development. + // + // NOTE: another reasonable approach here might be to return a + // randomized slice if we encounter an error, which would break the + // connection, but avoid panicking. This would perhaps be safer but + // significantly more confusing to users. + panic(fmt.Errorf("failed to construct HKDF label: %s", err)) + } out := make([]byte, length) - n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out) if err != nil || n != length { panic("tls: HKDF-Expand-Label invocation failed unexpectedly") } diff --git a/src/crypto/tls/ticket.go b/src/crypto/tls/ticket.go index 6c1d20da206da..b82ccd141e791 100644 --- a/src/crypto/tls/ticket.go +++ b/src/crypto/tls/ticket.go @@ -32,7 +32,7 @@ type sessionState struct { usedOldKey bool } -func (m *sessionState) marshal() []byte { +func (m *sessionState) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(m.vers) b.AddUint16(m.cipherSuite) @@ -47,7 +47,7 @@ func (m *sessionState) marshal() []byte { }) } }) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionState) unmarshal(data []byte) bool { @@ -86,7 +86,7 @@ type sessionStateTLS13 struct { certificate Certificate // CertificateEntry certificate_list<0..2^24-1>; } -func (m *sessionStateTLS13) marshal() []byte { +func (m *sessionStateTLS13) marshal() ([]byte, error) { var b cryptobyte.Builder b.AddUint16(VersionTLS13) b.AddUint8(0) // revision @@ -96,7 +96,7 @@ func (m *sessionStateTLS13) marshal() []byte { b.AddBytes(m.resumptionSecret) }) marshalCertificate(&b, m.certificate) - return b.BytesOrPanic() + return b.Bytes() } func (m *sessionStateTLS13) unmarshal(data []byte) bool { From 8e02cffd8e8a1d5d7b25bd46f675fc8ff9e841d0 Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Mon, 6 Feb 2023 10:09:00 -0800 Subject: [PATCH 15/17] [release-branch.go1.20] net/http: update bundled golang.org/x/net/http2 Disable cmd/internal/moddeps test, since this update includes PRIVATE track fixes. Fixes CVE-2022-41723 Fixes #58356 Updates #57855 Change-Id: I603886b5b76c16303dab1420d4ec8b7c7cdcf330 Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/1728940 Reviewed-by: Damien Neil Reviewed-by: Julie Qiu TryBot-Result: Security TryBots Reviewed-by: Tatiana Bradley Run-TryBot: Roland Shoemaker Reviewed-on: https://go-review.googlesource.com/c/go/+/468122 Auto-Submit: Michael Pratt TryBot-Result: Gopher Robot Run-TryBot: Michael Pratt Reviewed-by: Than McIntosh --- src/cmd/internal/moddeps/moddeps_test.go | 2 + .../golang.org/x/net/http2/hpack/hpack.go | 79 ++++++++++++------- 2 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/cmd/internal/moddeps/moddeps_test.go b/src/cmd/internal/moddeps/moddeps_test.go index 41220645c6ed1..25125face050e 100644 --- a/src/cmd/internal/moddeps/moddeps_test.go +++ b/src/cmd/internal/moddeps/moddeps_test.go @@ -31,6 +31,8 @@ import ( // See issues 36852, 41409, and 43687. // (Also see golang.org/issue/27348.) func TestAllDependencies(t *testing.T) { + t.Skip("TODO(#58356): 1.19.4 contains unreleased changes from vendored modules") + goBin := testenv.GoToolPath(t) // Ensure that all packages imported within GOROOT diff --git a/src/vendor/golang.org/x/net/http2/hpack/hpack.go b/src/vendor/golang.org/x/net/http2/hpack/hpack.go index ebdfbee964ae3..fe52df95e8cda 100644 --- a/src/vendor/golang.org/x/net/http2/hpack/hpack.go +++ b/src/vendor/golang.org/x/net/http2/hpack/hpack.go @@ -359,6 +359,7 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { var hf HeaderField wantStr := d.emitEnabled || it.indexed() + var undecodedName undecodedString if nameIdx > 0 { ihf, ok := d.at(nameIdx) if !ok { @@ -366,15 +367,27 @@ func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { } hf.Name = ihf.Name } else { - hf.Name, buf, err = d.readString(buf, wantStr) + undecodedName, buf, err = d.readString(buf) if err != nil { return err } } - hf.Value, buf, err = d.readString(buf, wantStr) + undecodedValue, buf, err := d.readString(buf) if err != nil { return err } + if wantStr { + if nameIdx <= 0 { + hf.Name, err = d.decodeString(undecodedName) + if err != nil { + return err + } + } + hf.Value, err = d.decodeString(undecodedValue) + if err != nil { + return err + } + } d.buf = buf if it.indexed() { d.dynTab.add(hf) @@ -459,46 +472,52 @@ func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) { return 0, origP, errNeedMore } -// readString decodes an hpack string from p. +// readString reads an hpack string from p. // -// wantStr is whether s will be used. If false, decompression and -// []byte->string garbage are skipped if s will be ignored -// anyway. This does mean that huffman decoding errors for non-indexed -// strings past the MAX_HEADER_LIST_SIZE are ignored, but the server -// is returning an error anyway, and because they're not indexed, the error -// won't affect the decoding state. -func (d *Decoder) readString(p []byte, wantStr bool) (s string, remain []byte, err error) { +// It returns a reference to the encoded string data to permit deferring decode costs +// until after the caller verifies all data is present. +func (d *Decoder) readString(p []byte) (u undecodedString, remain []byte, err error) { if len(p) == 0 { - return "", p, errNeedMore + return u, p, errNeedMore } isHuff := p[0]&128 != 0 strLen, p, err := readVarInt(7, p) if err != nil { - return "", p, err + return u, p, err } if d.maxStrLen != 0 && strLen > uint64(d.maxStrLen) { - return "", nil, ErrStringLength + // Returning an error here means Huffman decoding errors + // for non-indexed strings past the maximum string length + // are ignored, but the server is returning an error anyway + // and because the string is not indexed the error will not + // affect the decoding state. + return u, nil, ErrStringLength } if uint64(len(p)) < strLen { - return "", p, errNeedMore - } - if !isHuff { - if wantStr { - s = string(p[:strLen]) - } - return s, p[strLen:], nil + return u, p, errNeedMore } + u.isHuff = isHuff + u.b = p[:strLen] + return u, p[strLen:], nil +} - if wantStr { - buf := bufPool.Get().(*bytes.Buffer) - buf.Reset() // don't trust others - defer bufPool.Put(buf) - if err := huffmanDecode(buf, d.maxStrLen, p[:strLen]); err != nil { - buf.Reset() - return "", nil, err - } +type undecodedString struct { + isHuff bool + b []byte +} + +func (d *Decoder) decodeString(u undecodedString) (string, error) { + if !u.isHuff { + return string(u.b), nil + } + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() // don't trust others + var s string + err := huffmanDecode(buf, d.maxStrLen, u.b) + if err == nil { s = buf.String() - buf.Reset() // be nice to GC } - return s, p[strLen:], nil + buf.Reset() // be nice to GC + bufPool.Put(buf) + return s, err } From 202a1a57064127c3f19d96df57b9f9586145e21c Mon Sep 17 00:00:00 2001 From: Gopher Robot Date: Tue, 14 Feb 2023 17:53:38 +0000 Subject: [PATCH 16/17] [release-branch.go1.20] go1.20.1 Change-Id: I6a40cdd44d7bc7e4bf95a5169ecad16757eb41d3 Reviewed-on: https://go-review.googlesource.com/c/go/+/468238 Auto-Submit: Gopher Robot Reviewed-by: Michael Pratt Run-TryBot: Gopher Robot Reviewed-by: Than McIntosh TryBot-Result: Gopher Robot --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 83534e24796a8..866106008f2eb 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -go1.20 \ No newline at end of file +go1.20.1 \ No newline at end of file From 828b05cc647e9777c7e8c67fdd9d5bef2b842d31 Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Tue, 14 Feb 2023 15:15:23 -0500 Subject: [PATCH 17/17] [release-branch.go1.20] all: update vendored golang.org/x/net Update golang.org/x/net to the tip of internal-branch.go1.20-vendor to include CL 468336. The contents of that CL were already merged into this branch in CL 468122, so this CL just brings go.mod back in line to matching the actual vendored content. For #58356 For #57855 Change-Id: I6ee9483077630c11c725927f38f6b69a784106db Reviewed-on: https://go-review.googlesource.com/c/go/+/468302 Run-TryBot: Michael Pratt TryBot-Result: Gopher Robot Reviewed-by: Than McIntosh Auto-Submit: Michael Pratt --- src/cmd/internal/moddeps/moddeps_test.go | 2 -- src/go.mod | 2 +- src/go.sum | 4 ++-- src/vendor/modules.txt | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/cmd/internal/moddeps/moddeps_test.go b/src/cmd/internal/moddeps/moddeps_test.go index 25125face050e..41220645c6ed1 100644 --- a/src/cmd/internal/moddeps/moddeps_test.go +++ b/src/cmd/internal/moddeps/moddeps_test.go @@ -31,8 +31,6 @@ import ( // See issues 36852, 41409, and 43687. // (Also see golang.org/issue/27348.) func TestAllDependencies(t *testing.T) { - t.Skip("TODO(#58356): 1.19.4 contains unreleased changes from vendored modules") - goBin := testenv.GoToolPath(t) // Ensure that all packages imported within GOROOT diff --git a/src/go.mod b/src/go.mod index 2a1261f925a84..4697da201c0b4 100644 --- a/src/go.mod +++ b/src/go.mod @@ -4,7 +4,7 @@ go 1.20 require ( golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a - golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 + golang.org/x/net v0.4.1-0.20230214201333-88ed8ca3307d ) require ( diff --git a/src/go.sum b/src/go.sum index ef6748d5968c2..625f2070b3487 100644 --- a/src/go.sum +++ b/src/go.sum @@ -1,7 +1,7 @@ golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a h1:diz9pEYuTIuLMJLs3rGDkeaTsNyRs6duYdFyPAxzE/U= golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 h1:Frnccbp+ok2GkUS2tC84yAq/U9Vg+0sIO7aRL3T4Xnc= -golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= +golang.org/x/net v0.4.1-0.20230214201333-88ed8ca3307d h1:KHU/KRz6+/yWyRHEC24m7T5gou5VSh62duch955ktBY= +golang.org/x/net v0.4.1-0.20230214201333-88ed8ca3307d/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= diff --git a/src/vendor/modules.txt b/src/vendor/modules.txt index 3e4bb5b90bc67..89a7c86c41dbd 100644 --- a/src/vendor/modules.txt +++ b/src/vendor/modules.txt @@ -7,7 +7,7 @@ golang.org/x/crypto/cryptobyte/asn1 golang.org/x/crypto/hkdf golang.org/x/crypto/internal/alias golang.org/x/crypto/internal/poly1305 -# golang.org/x/net v0.3.1-0.20221206200815-1e63c2f08a10 +# golang.org/x/net v0.4.1-0.20230214201333-88ed8ca3307d ## explicit; go 1.17 golang.org/x/net/dns/dnsmessage golang.org/x/net/http/httpguts