Skip to content

Commit 928e3d9

Browse files
committed
encoding/xml: call MarshalXML(), MarshalXMLAttr(), and MarshalText() defined with pointer receivers even for non-addressable values of non-pointer types on marshalling XML
1 parent f0d880e commit 928e3d9

File tree

2 files changed

+140
-28
lines changed

2 files changed

+140
-28
lines changed

src/encoding/xml/marshal.go

+33-28
Original file line numberDiff line numberDiff line change
@@ -451,22 +451,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
451451
if val.CanInterface() && typ.Implements(marshalerType) {
452452
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
453453
}
454+
455+
var pv reflect.Value
454456
if val.CanAddr() {
455-
pv := val.Addr()
456-
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
457-
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
458-
}
457+
pv = val.Addr()
458+
} else {
459+
pv = reflect.New(typ)
460+
pv.Elem().Set(val)
461+
}
462+
463+
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
464+
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
459465
}
460466

461467
// Check for text marshaler.
462468
if val.CanInterface() && typ.Implements(textMarshalerType) {
463469
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
464470
}
465-
if val.CanAddr() {
466-
pv := val.Addr()
467-
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
468-
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
469-
}
471+
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
472+
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
470473
}
471474

472475
// Slices and arrays iterate over the elements. They do not have an enclosing tag.
@@ -589,18 +592,23 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
589592
return nil
590593
}
591594

595+
var pv reflect.Value
592596
if val.CanAddr() {
593-
pv := val.Addr()
594-
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
595-
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
596-
if err != nil {
597-
return err
598-
}
599-
if attr.Name.Local != "" {
600-
start.Attr = append(start.Attr, attr)
601-
}
602-
return nil
597+
pv = val.Addr()
598+
} else {
599+
pv = reflect.New(val.Type())
600+
pv.Elem().Set(val)
601+
}
602+
603+
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
604+
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
605+
if err != nil {
606+
return err
603607
}
608+
if attr.Name.Local != "" {
609+
start.Attr = append(start.Attr, attr)
610+
}
611+
return nil
604612
}
605613

606614
if val.CanInterface() && val.Type().Implements(textMarshalerType) {
@@ -612,16 +620,13 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
612620
return nil
613621
}
614622

615-
if val.CanAddr() {
616-
pv := val.Addr()
617-
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
618-
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
619-
if err != nil {
620-
return err
621-
}
622-
start.Attr = append(start.Attr, Attr{name, string(text)})
623-
return nil
623+
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
624+
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
625+
if err != nil {
626+
return err
624627
}
628+
start.Attr = append(start.Attr, Attr{name, string(text)})
629+
return nil
625630
}
626631

627632
// Dereference or skip nil pointer, interface values.

src/encoding/xml/marshal_test.go

+107
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package xml
66

77
import (
88
"bytes"
9+
"encoding"
910
"errors"
1011
"fmt"
1112
"io"
@@ -2589,3 +2590,109 @@ func TestClose(t *testing.T) {
25892590
})
25902591
}
25912592
}
2593+
2594+
type structWithMarshalXML struct{ V int }
2595+
2596+
func (s *structWithMarshalXML) MarshalXML(e *Encoder, _ StartElement) error {
2597+
_ = e.EncodeToken(StartElement{Name: Name{Local: "marshalled"}})
2598+
_ = e.EncodeToken(CharData(strconv.Itoa(s.V)))
2599+
_ = e.EncodeToken(EndElement{Name: Name{Local: "marshalled"}})
2600+
return nil
2601+
}
2602+
2603+
var _ = Marshaler(&structWithMarshalXML{})
2604+
2605+
type embedderX struct {
2606+
V structWithMarshalXML
2607+
}
2608+
2609+
func TestMarshalXMLWithPointerXMLMarshalers(t *testing.T) {
2610+
for _, test := range []struct {
2611+
name string
2612+
v interface{}
2613+
expected string
2614+
}{
2615+
{name: "a value with MarshalXML", v: structWithMarshalXML{V: 1}, expected: `<marshalled>1</marshalled>`},
2616+
{name: "pointer to a value with MarshalXML", v: &structWithMarshalXML{V: 1}, expected: "<marshalled>1</marshalled>"},
2617+
{name: "a struct with a value with MarshalXML", v: embedderX{V: structWithMarshalXML{V: 1}}, expected: "<embedderX><marshalled>1</marshalled></embedderX>"},
2618+
{name: "a slice of structs with a value with MarshalXML", v: []embedderX{{V: structWithMarshalXML{V: 1}}}, expected: `<embedderX><marshalled>1</marshalled></embedderX>`},
2619+
} {
2620+
test := test
2621+
t.Run(test.name, func(t *testing.T) {
2622+
result, err := Marshal(test.v)
2623+
if err != nil {
2624+
t.Fatalf("Marshal error: %v", err)
2625+
}
2626+
if string(result) != test.expected {
2627+
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
2628+
}
2629+
})
2630+
}
2631+
}
2632+
2633+
type structWithMarshalText struct{ V int }
2634+
2635+
func (s *structWithMarshalText) MarshalText() ([]byte, error) {
2636+
return []byte(fmt.Sprintf("marshalled(%d)", s.V)), nil
2637+
}
2638+
2639+
var _ = encoding.TextMarshaler(&structWithMarshalText{})
2640+
2641+
type embedderT struct {
2642+
V structWithMarshalText
2643+
}
2644+
2645+
func TestMarshalXMLWithPointerTextMarshalers(t *testing.T) {
2646+
for _, test := range []struct {
2647+
name string
2648+
v interface{}
2649+
expected string
2650+
}{
2651+
{name: "a value with MarshalText", v: structWithMarshalText{V: 1}, expected: "<structWithMarshalText>marshalled(1)</structWithMarshalText>"},
2652+
{name: "pointer to a value with MarshalText", v: &structWithMarshalText{V: 1}, expected: "<structWithMarshalText>marshalled(1)</structWithMarshalText>"},
2653+
{name: "a struct with a value with MarshalText", v: embedderT{V: structWithMarshalText{V: 1}}, expected: "<embedderT><V>marshalled(1)</V></embedderT>"},
2654+
{name: "a slice of structs with a value with MarshalText", v: []embedderT{{V: structWithMarshalText{V: 1}}}, expected: "<embedderT><V>marshalled(1)</V></embedderT>"},
2655+
} {
2656+
test := test
2657+
t.Run(test.name, func(t *testing.T) {
2658+
result, err := Marshal(test.v)
2659+
if err != nil {
2660+
t.Fatalf("Marshal error: %v", err)
2661+
}
2662+
if string(result) != test.expected {
2663+
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
2664+
}
2665+
})
2666+
}
2667+
}
2668+
2669+
type structWithMarshalXMLAttr struct{ v int }
2670+
2671+
func (s *structWithMarshalXMLAttr) MarshalXMLAttr(name Name) (Attr, error) {
2672+
return Attr{Name: Name{Local: "marshalled"}, Value: strconv.Itoa(s.v)}, nil
2673+
}
2674+
2675+
var _ = MarshalerAttr(&structWithMarshalXMLAttr{})
2676+
2677+
type embedderAT struct {
2678+
X structWithMarshalXMLAttr `xml:"X,attr"`
2679+
T structWithMarshalText `xml:"T,attr"`
2680+
XP *structWithMarshalXMLAttr `xml:"XP,attr"`
2681+
XT *structWithMarshalText `xml:"XT,attr"`
2682+
}
2683+
2684+
func TestMarshalXMLWithPointerAttrMarshalers(t *testing.T) {
2685+
result, err := Marshal(embedderAT{
2686+
X: structWithMarshalXMLAttr{1},
2687+
T: structWithMarshalText{2},
2688+
XP: &structWithMarshalXMLAttr{3},
2689+
XT: &structWithMarshalText{4},
2690+
})
2691+
if err != nil {
2692+
t.Fatalf("Marshal error: %v", err)
2693+
}
2694+
expected := `<embedderAT marshalled="1" T="marshalled(2)" marshalled="3" XT="marshalled(4)"></embedderAT>`
2695+
if string(result) != expected {
2696+
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, expected)
2697+
}
2698+
}

0 commit comments

Comments
 (0)