diff options
author | Samuel Giddins <[email protected]> | 2023-08-18 13:35:23 -0700 |
---|---|---|
committer | git <[email protected]> | 2023-09-20 02:02:58 +0000 |
commit | d182d83ce929cd322f4a6fd134cd31be950eca77 (patch) | |
tree | 18808133999aab348990d35cae6920d0252dfa56 | |
parent | c47608494f961d2a8fe24b1a7b7f627b305cf7fe (diff) |
[rubygems/rubygems] Add a Marshal.load replacement that walks an AST to safely load permitted classes/symbols
https://github.com/rubygems/rubygems/commit/7e4478fe73
-rw-r--r-- | lib/rubygems.rb | 10 | ||||
-rw-r--r-- | lib/rubygems/indexer.rb | 3 | ||||
-rw-r--r-- | lib/rubygems/safe_marshal.rb | 71 | ||||
-rw-r--r-- | lib/rubygems/safe_marshal/elements.rb | 138 | ||||
-rw-r--r-- | lib/rubygems/safe_marshal/reader.rb | 182 | ||||
-rw-r--r-- | lib/rubygems/safe_marshal/visitors/to_ruby.rb | 266 | ||||
-rw-r--r-- | lib/rubygems/safe_marshal/visitors/visitor.rb | 74 | ||||
-rw-r--r-- | lib/rubygems/source.rb | 9 | ||||
-rw-r--r-- | lib/rubygems/specification.rb | 3 | ||||
-rw-r--r-- | test/rubygems/test_gem_safe_marshal.rb | 144 |
10 files changed, 895 insertions, 5 deletions
diff --git a/lib/rubygems.rb b/lib/rubygems.rb index cb71657018..557d7aa8eb 100644 --- a/lib/rubygems.rb +++ b/lib/rubygems.rb @@ -604,6 +604,16 @@ An Array (#{env.inspect}) was passed in from #{caller[3]} @yaml_loaded = true end + @safe_marshal_loaded = false + + def self.load_safe_marshal + return if @safe_marshal_loaded + + require_relative "rubygems/safe_marshal" + + @safe_marshal_loaded = true + end + ## # The file name and line number of the caller of the caller of this method. # diff --git a/lib/rubygems/indexer.rb b/lib/rubygems/indexer.rb index c6691517b3..f4c981b9ef 100644 --- a/lib/rubygems/indexer.rb +++ b/lib/rubygems/indexer.rb @@ -411,7 +411,8 @@ class Gem::Indexer # +dest+. For a latest index, does not ensure the new file is minimal. def update_specs_index(index, source, dest) - specs_index = Marshal.load Gem.read_binary(source) + Gem.load_safe_marshal + specs_index = Gem::SafeMarshal.safe_load Gem.read_binary(source) index.each do |spec| platform = spec.original_platform diff --git a/lib/rubygems/safe_marshal.rb b/lib/rubygems/safe_marshal.rb new file mode 100644 index 0000000000..172b92d5a6 --- /dev/null +++ b/lib/rubygems/safe_marshal.rb @@ -0,0 +1,71 @@ +# frozen_string_literal: true + +require_relative "safe_marshal/reader" +require_relative "safe_marshal/visitors/to_ruby" + +module Gem + ### + # This module is used for safely loading Marshal specs from a gem. The + # `safe_load` method defined on this module is specifically designed for + # loading Gem specifications. + + module SafeMarshal + PERMITTED_CLASSES = %w[ + Time + Date + + Gem::Dependency + Gem::NameTuple + Gem::Platform + Gem::Requirement + Gem::Specification + Gem::Version + Gem::Version::Requirement + + YAML::Syck::DefaultKey + YAML::PrivateType + ].freeze + private_constant :PERMITTED_CLASSES + + PERMITTED_SYMBOLS = %w[ + E + + offset + zone + nano_num + nano_den + submicro + + @_zone + @cpu + @force_ruby_platform + @marshal_with_utc_coercion + @name + @os + @platform + @prerelease + @requirement + @taguri + @type + @type_id + @value + @version + @version_requirement + @version_requirements + + development + runtime + ].freeze + private_constant :PERMITTED_SYMBOLS + + def self.safe_load(input) + load(input, permitted_classes: PERMITTED_CLASSES, permitted_symbols: PERMITTED_SYMBOLS) + end + + def self.load(input, permitted_classes: [::Symbol], permitted_symbols: []) + root = Reader.new(StringIO.new(input, "r")).read! + + Visitors::ToRuby.new(permitted_classes: permitted_classes, permitted_symbols: permitted_symbols).visit(root) + end + end +end diff --git a/lib/rubygems/safe_marshal/elements.rb b/lib/rubygems/safe_marshal/elements.rb new file mode 100644 index 0000000000..70961c40ff --- /dev/null +++ b/lib/rubygems/safe_marshal/elements.rb @@ -0,0 +1,138 @@ +# frozen_string_literal: true + +module Gem + module SafeMarshal + module Elements + class Element + end + + class Symbol < Element + def initialize(name:) + @name = name + end + attr_reader :name + end + + class UserDefined < Element + def initialize(name:, binary_string:) + @name = name + @binary_string = binary_string + end + + attr_reader :name, :binary_string + end + + class UserMarshal < Element + def initialize(name:, data:) + @name = name + @data = data + end + + attr_reader :name, :data + end + + class String < Element + def initialize(str:) + @str = str + end + + attr_reader :str + end + + class Hash < Element + def initialize(pairs:) + @pairs = pairs + end + + attr_reader :pairs + end + + class HashWithDefaultValue < Hash + def initialize(default:, **kwargs) + super(**kwargs) + @default = default + end + + attr_reader :default + end + + class Array < Element + def initialize(elements:) + @elements = elements + end + + attr_reader :elements + end + + class Integer < Element + def initialize(int:) + @int = int + end + + attr_reader :int + end + + class True < Element + def initialize + end + TRUE = new.freeze + end + + class False < Element + def initialize + end + + FALSE = new.freeze + end + + class WithIvars < Element + def initialize(object:,ivars:) + @object = object + @ivars = ivars + end + + attr_reader :object, :ivars + end + + class Object < Element + def initialize(name:) + @name = name + end + attr_reader :name + end + + class Nil < Element + NIL = new.freeze + end + + class ObjectLink < Element + def initialize(offset:) + @offset = offset + end + attr_reader :offset + end + + class SymbolLink < Element + def initialize(offset:) + @offset = offset + end + attr_reader :offset + end + + class Float < Element + def initialize(string:) + @string = string + end + attr_reader :string + end + + class Bignum < Element # rubocop:disable Lint/UnifiedInteger + def initialize(sign:, data:) + @sign = sign + @data = data + end + attr_reader :sign, :data + end + end + end +end diff --git a/lib/rubygems/safe_marshal/reader.rb b/lib/rubygems/safe_marshal/reader.rb new file mode 100644 index 0000000000..105984ff04 --- /dev/null +++ b/lib/rubygems/safe_marshal/reader.rb @@ -0,0 +1,182 @@ +# frozen_string_literal: true + +require_relative "elements" + +module Gem + module SafeMarshal + class Reader + class UnconsumedBytesError < StandardError + end + + def initialize(io) + @io = io + end + + def read! + read_header + root = read_element + raise UnconsumedBytesError unless @io.eof? + root + end + + private + + MARSHAL_VERSION = [Marshal::MAJOR_VERSION, Marshal::MINOR_VERSION].map(&:chr).join.freeze + private_constant :MARSHAL_VERSION + + def read_header + v = @io.read(2) + raise "Unsupported marshal version #{v.inspect}, expected #{MARSHAL_VERSION.inspect}" unless v == MARSHAL_VERSION + end + + def read_byte + @io.getbyte + end + + def read_integer + b = read_byte + + case b + when 0x00 + 0 + when 0x01 + @io.read(1).unpack1("C") + when 0x02 + @io.read(2).unpack1("S<") + when 0x03 + (@io.read(3) + "\0").unpack1("L<") + when 0x04 + @io.read(4).unpack1("L<") + when 0xFC + @io.read(4).unpack1("L<") | -0x100000000 + when 0xFD + (@io.read(3) + "\0").unpack1("L<") | -0x1000000 + when 0xFE + @io.read(2).unpack1("s<") | -0x10000 + when 0xFF + read_byte | -0x100 + else + signed = (b ^ 128) - 128 + if b >= 128 + signed + 5 + else + signed - 5 + end + end + end + + def read_element + type = read_byte + case type + when 34 then read_string # ?" + when 48 then read_nil # ?0 + when 58 then read_symbol # ?: + when 59 then read_symbol_link # ?; + when 64 then read_object_link # ?@ + when 70 then read_false # ?F + when 73 then read_object_with_ivars # ?I + when 84 then read_true # ?T + when 85 then read_user_marshal # ?U + when 91 then read_array # ?[ + when 102 then read_float # ?f + when 105 then Elements::Integer.new int: read_integer # ?i + when 108 then read_bignum + when 111 then read_object # ?o + when 117 then read_user_defined # ?u + when 123 then read_hash # ?{ + when 125 then read_hash_with_default_value # ?} + when "e".ord then read_extended_object + when "c".ord then read_class + when "m".ord then read_module + when "M".ord then read_class_or_module + when "d".ord then read_data + when "/".ord then read_regexp + when "S".ord then read_struct + when "C".ord then read_user_class + else + raise "Unsupported marshal type discriminator #{type.chr.inspect} (#{type})" + end + end + + def read_symbol + Elements::Symbol.new name: @io.read(read_integer) + end + + def read_string + Elements::String.new(str: @io.read(read_integer)) + end + + def read_true + Elements::True::TRUE + end + + def read_false + Elements::False::FALSE + end + + def read_user_defined + Elements::UserDefined.new(name: read_element, binary_string: @io.read(read_integer)) + end + + def read_array + Elements::Array.new(elements: Array.new(read_integer) do |_i| + read_element + end) + end + + def read_object_with_ivars + Elements::WithIvars.new(object: read_element, ivars: + Array.new(read_integer) do + [read_element, read_element] + end) + end + + def read_symbol_link + Elements::SymbolLink.new offset: read_integer + end + + def read_user_marshal + Elements::UserMarshal.new(name: read_element, data: read_element) + end + + def read_object_link + Elements::ObjectLink.new(offset: read_integer) + end + + def read_hash + pairs = Array.new(read_integer) do + [read_element, read_element] + end + Elements::Hash.new(pairs: pairs) + end + + def read_hash_with_default_value + pairs = Array.new(read_integer) do + [read_element, read_element] + end + Elements::HashWithDefaultValue.new(pairs: pairs, default: read_element) + end + + def read_object + Elements::WithIvars.new( + object: Elements::Object.new(name: read_element), + ivars: Array.new(read_integer) do + [read_element, read_element] + end + ) + end + + def read_nil + Elements::Nil::NIL + end + + def read_float + Elements::Float.new string: @io.read(read_integer) + end + + def read_bignum + Elements::Bignum.new(sign: read_byte, data: @io.read(read_integer * 2)) + end + end + end +end diff --git a/lib/rubygems/safe_marshal/visitors/to_ruby.rb b/lib/rubygems/safe_marshal/visitors/to_ruby.rb new file mode 100644 index 0000000000..f81b91fb46 --- /dev/null +++ b/lib/rubygems/safe_marshal/visitors/to_ruby.rb @@ -0,0 +1,266 @@ +# frozen_string_literal: true + +require_relative "visitor" + +module Gem::SafeMarshal + module Visitors + class ToRuby < Visitor + def initialize(permitted_classes:, permitted_symbols:) + @permitted_classes = permitted_classes + @permitted_symbols = permitted_symbols | permitted_classes | ["E"] + + @objects = [] + @symbols = [] + @class_cache = {} + + @stack = ["root"] + end + + def inspect # :nodoc: + format("#<%s permitted_classes: %p permitted_symbols: %p>", self.class, @permitted_classes, @permitted_symbols) + end + + def visit(target) + depth = @stack.size + super + ensure + @stack.slice!(depth.pred..) + end + + private + + def visit_Gem_SafeMarshal_Elements_Array(a) + register_object([]).replace(a.elements.each_with_index.map do |e, i| + @stack << "[#{i}]" + visit(e) + end) + end + + def visit_Gem_SafeMarshal_Elements_Symbol(s) + resolve_symbol(s.name) + end + + def map_ivars(ivars) + ivars.map.with_index do |(k, v), i| + @stack << "ivar #{i}" + k = visit(k) + @stack << k + next k, visit(v) + end + end + + def visit_Gem_SafeMarshal_Elements_WithIvars(e) + idx = 0 + object_offset = @objects.size + @stack << "object" + object = visit(e.object) + ivars = map_ivars(e.ivars) + + case e.object + when Elements::UserDefined + if object.class == ::Time + offset = zone = nano_num = nano_den = nil + ivars.reject! do |k, v| + case k + when :offset + offset = v + when :zone + zone = v + when :nano_num + nano_num = v + when :nano_den + nano_den = v + when :submicro + else + next false + end + true + end + object = object.localtime offset if offset + if (nano_den || nano_num) && !(nano_den && nano_num) + raise FormatError, "Must have all of nano_den, nano_num for Time #{e.pretty_inspect}" + elsif nano_den && nano_num + nano = Rational(nano_num, nano_den) + nsec, subnano = nano.divmod(1) + nano = nsec + subnano + + object = Time.at(object.to_r, nano, :nanosecond) + end + if zone + require "time" + Time.send(:force_zone!, object, zone, offset) + end + @objects[object_offset] = object + end + when Elements::String + enc = nil + + ivars.each do |k, v| + case k + when :E + case v + when TrueClass + enc = "UTF-8" + when FalseClass + enc = "US-ASCII" + end + else + break + end + idx += 1 + end + + object.replace ::String.new(object, encoding: enc) + end + + ivars[idx..].each do |k, v| + object.instance_variable_set k, v + end + object + end + + def visit_Gem_SafeMarshal_Elements_Hash(o) + hash = register_object({}) + + o.pairs.each_with_index do |(k, v), i| + @stack << i + k = visit(k) + @stack << k + hash[k] = visit(v) + end + + hash + end + + def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(o) + hash = visit_Gem_SafeMarshal_Elements_Hash(o) + @stack << :default + hash.default = visit(o.default) + hash + end + + def visit_Gem_SafeMarshal_Elements_Object(o) + register_object(resolve_class(o.name).allocate) + end + + def visit_Gem_SafeMarshal_Elements_ObjectLink(o) + @objects[o.offset] + end + + def visit_Gem_SafeMarshal_Elements_SymbolLink(o) + @symbols[o.offset] + end + + def visit_Gem_SafeMarshal_Elements_UserDefined(o) + register_object(resolve_class(o.name).send(:_load, o.binary_string)) + end + + def visit_Gem_SafeMarshal_Elements_UserMarshal(o) + register_object(resolve_class(o.name).allocate).tap do |object| + @stack << :data + object.marshal_load visit(o.data) + end + end + + def visit_Gem_SafeMarshal_Elements_Integer(i) + i.int + end + + def visit_Gem_SafeMarshal_Elements_Nil(_) + nil + end + + def visit_Gem_SafeMarshal_Elements_True(_) + true + end + + def visit_Gem_SafeMarshal_Elements_False(_) + false + end + + def visit_Gem_SafeMarshal_Elements_String(s) + register_object(s.str) + end + + def visit_Gem_SafeMarshal_Elements_Float(f) + case f.string + when "inf" + ::Float::INFINITY + when "-inf" + -::Float::INFINITY + when "nan" + ::Float::NAN + else + f.string.to_f + end + end + + def visit_Gem_SafeMarshal_Elements_Bignum(b) + result = 0 + b.data.each_byte.with_index do |byte, exp| + result += (byte * 2**(exp * 8)) + end + + case b.sign + when 43 # ?+ + result + when 45 # ?- + -result + else + raise FormatError, "Unexpected sign for Bignum #{b.sign.chr.inspect} (#{b.sign})" + end + end + + def resolve_class(n) + @class_cache[n] ||= begin + name = nil + case n + when Elements::Symbol, Elements::SymbolLink + @stack << "class name" + name = visit(n) + else + raise FormatError, "Class names must be Symbol or SymbolLink" + end + to_s = name.to_s + raise UnpermittedClassError.new(name: name, stack: @stack.dup) unless @permitted_classes.include?(to_s) + begin + ::Object.const_get(to_s) + rescue NameError + raise ArgumentError, "Undefined class #{to_s.inspect}" + end + end + end + + def resolve_symbol(name) + raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name) + sym = name.to_sym + @symbols << sym + sym + end + + def register_object(o) + @objects << o + o + end + + class UnpermittedSymbolError < StandardError + def initialize(symbol:, stack:) + @symbol = symbol + @stack = stack + super "Attempting to load unpermitted symbol #{symbol.inspect} @ #{stack.join "."}" + end + end + + class UnpermittedClassError < StandardError + def initialize(name:, stack:) + @name = name + @stack = stack + super "Attempting to load unpermitted class #{name.inspect} @ #{stack.join "."}" + end + end + + class FormatError < StandardError + end + end + end +end diff --git a/lib/rubygems/safe_marshal/visitors/visitor.rb b/lib/rubygems/safe_marshal/visitors/visitor.rb new file mode 100644 index 0000000000..c9a079dc0e --- /dev/null +++ b/lib/rubygems/safe_marshal/visitors/visitor.rb @@ -0,0 +1,74 @@ +# frozen_string_literal: true + +module Gem::SafeMarshal::Visitors + class Visitor + def visit(target) + send DISPATCH.fetch(target.class), target + end + + private + + DISPATCH = Gem::SafeMarshal::Elements.constants.each_with_object({}) do |c, h| + next if c == :Element + + klass = Gem::SafeMarshal::Elements.const_get(c) + h[klass] = :"visit_#{klass.name.gsub("::", "_")}" + h.default = :visit_unknown_element + end.compare_by_identity.freeze + private_constant :DISPATCH + + def visit_unknown_element(e) + raise ArgumentError, "Attempting to visit unknown element #{e.inspect}" + end + + def visit_Gem_SafeMarshal_Elements_Array(target) + target.elements.each {|e| visit(e) } + end + + def visit_Gem_SafeMarshal_Elements_Bignum(target); end + def visit_Gem_SafeMarshal_Elements_False(target); end + def visit_Gem_SafeMarshal_Elements_Float(target); end + + def visit_Gem_SafeMarshal_Elements_Hash(target) + target.pairs.each do |k, v| + visit(k) + visit(v) + end + end + + def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(target) + visit_Gem_SafeMarshal_Elements_Hash(target) + visit(target.default) + end + + def visit_Gem_SafeMarshal_Elements_Integer(target); end + def visit_Gem_SafeMarshal_Elements_Nil(target); end + + def visit_Gem_SafeMarshal_Elements_Object(target) + visit(target.name) + end + + def visit_Gem_SafeMarshal_Elements_ObjectLink(target); end + def visit_Gem_SafeMarshal_Elements_String(target); end + def visit_Gem_SafeMarshal_Elements_Symbol(target); end + def visit_Gem_SafeMarshal_Elements_SymbolLink(target); end + def visit_Gem_SafeMarshal_Elements_True(target); end + + def visit_Gem_SafeMarshal_Elements_UserDefined(target) + visit(target.name) + end + + def visit_Gem_SafeMarshal_Elements_UserMarshal(target) + visit(target.name) + visit(target.data) + end + + def visit_Gem_SafeMarshal_Elements_WithIvars(target) + visit(target.object) + target.ivars.each do |k, v| + visit(k) + visit(v) + end + end + end +end diff --git a/lib/rubygems/source.rb b/lib/rubygems/source.rb index 8b3a8828d1..7c5b746a43 100644 --- a/lib/rubygems/source.rb +++ b/lib/rubygems/source.rb @@ -135,8 +135,9 @@ class Gem::Source if File.exist? local_spec spec = Gem.read_binary local_spec + Gem.load_safe_marshal spec = begin - Marshal.load(spec) + Gem::SafeMarshal.safe_load(spec) rescue StandardError nil end @@ -157,8 +158,9 @@ class Gem::Source end end + Gem.load_safe_marshal # TODO: Investigate setting Gem::Specification#loaded_from to a URI - Marshal.load spec + Gem::SafeMarshal.safe_load spec end ## @@ -188,8 +190,9 @@ class Gem::Source spec_dump = fetcher.cache_update_path spec_path, local_file, update_cache? + Gem.load_safe_marshal begin - Gem::NameTuple.from_list Marshal.load(spec_dump) + Gem::NameTuple.from_list Gem::SafeMarshal.safe_load(spec_dump) rescue ArgumentError if update_cache? && !retried FileUtils.rm local_file diff --git a/lib/rubygems/specification.rb b/lib/rubygems/specification.rb index 0b9bf143df..8a1cd945a5 100644 --- a/lib/rubygems/specification.rb +++ b/lib/rubygems/specification.rb @@ -1300,12 +1300,13 @@ class Gem::Specification < Gem::BasicSpecification def self._load(str) Gem.load_yaml + Gem.load_safe_marshal yaml_set = false retry_count = 0 array = begin - Marshal.load str + Gem::SafeMarshal.safe_load str rescue ArgumentError => e # Avoid an infinite retry loop when the argument error has nothing to do # with the classes not being defined. diff --git a/test/rubygems/test_gem_safe_marshal.rb b/test/rubygems/test_gem_safe_marshal.rb new file mode 100644 index 0000000000..5133a63622 --- /dev/null +++ b/test/rubygems/test_gem_safe_marshal.rb @@ -0,0 +1,144 @@ +# frozen_string_literal: true + +require_relative "helper" + +require "date" +require "rubygems/safe_marshal" + +class TestGemSafeMarshal < Gem::TestCase + def test_repeated_symbol + assert_safe_load_as [:development, :development] + end + + def test_repeated_string + s = "hello" + a = [s] + assert_safe_load_as [s, a, s, a] + assert_safe_load_as [s, s] + end + + def test_recursive_string + s = String.new("hello") + s.instance_variable_set(:@type, s) + assert_safe_load_as s, additional_methods: [:instance_variables] + end + + def test_recursive_array + a = [] + a << a + assert_safe_load_as a + end + + def test_time_loads + assert_safe_load_as Time.new + end + + def test_time_with_zone_loads + assert_safe_load_as Time.now(in: "+04:00") + end + + def test_string_with_encoding + assert_safe_load_as String.new("abc", encoding: "US-ASCII") + assert_safe_load_as String.new("abc", encoding: "UTF-8") + end + + def test_string_with_ivar + assert_safe_load_as String.new("abc").tap { _1.instance_variable_set :@type, "type" } + end + + def test_time_with_ivar + assert_safe_load_as Time.new.tap { _1.instance_variable_set :@type, "type" } + end + + secs = Time.new(2000, 12, 31, 23, 59, 59).to_i + [ + Time.at(secs, 1, :millisecond), + Time.at(secs, 1.1, :millisecond), + Time.at(secs, 1.01, :millisecond), + Time.at(secs, 1, :microsecond), + Time.at(secs, 1.1, :microsecond), + Time.at(secs, 1.01, :microsecond), + Time.at(secs, 1, :nanosecond), + Time.at(secs, 1.1, :nanosecond), + Time.at(secs, 1.01, :nanosecond), + Time.at(secs, 1.001, :nanosecond), + Time.at(secs, 1.00001, :nanosecond), + Time.at(secs, 1.00001, :nanosecond).tap {|t| t.instance_variable_set :@type, "type" }, + ].each_with_index do |t, i| + define_method("test_time_#{i} #{t.inspect}") do + assert_safe_load_as t, additional_methods: [:ctime, :to_f, :to_r, :to_i, :zone, :subsec, :instance_variables, :to_a] + end + end + + def test_floats + [0.0, Float::INFINITY, Float::NAN, 1.1, 3e7].each do |f| + assert_safe_load_as f + assert_safe_load_as(-f) + end + end + + def test_hash_with_ivar + assert_safe_load_as({ runtime: :development }.tap { _1.instance_variable_set :@type, "null" }) + end + + def test_hash_with_default_value + assert_safe_load_as Hash.new([]) + end + + def test_frozen_object + assert_safe_load_as Gem::Version.new("1.abc").freeze + end + + def test_date + assert_safe_load_as Date.new + end + + [ + 0, 1, 2, 3, 4, 5, 6, 122, 123, 124, 127, 128, 255, 256, 257, + 2**16, 2**16 - 1, 2**20 - 1, + 2**28, 2**28 - 1, + 2**32, 2**32 - 1, + 2**63, 2**63 - 1 + ]. + each do |i| + define_method("test_int_ #{i}") do + assert_safe_load_as i + assert_safe_load_as(-i) + assert_safe_load_as(i + 1) + assert_safe_load_as(i - 1) + end + end + + def test_gem_spec_disallowed_symbol + e = assert_raise(Gem::SafeMarshal::Visitors::ToRuby::UnpermittedSymbolError) do + spec = Gem::Specification.new do |s| + s.name = "hi" + s.version = "1.2.3" + + s.dependencies << Gem::Dependency.new("rspec", Gem::Requirement.new([">= 1.2.3"]), :runtime).tap { _1.instance_variable_set(:@name, :rspec) } + end + Gem::SafeMarshal.safe_load(Marshal.dump(spec)) + end + + assert_equal e.message, "Attempting to load unpermitted symbol \"rspec\" @ root.[9].[0].@name" + end + + def assert_safe_load_as(x, additional_methods: []) + dumped = Marshal.dump(x) + loaded = Marshal.load(dumped) + safe_loaded = Gem::SafeMarshal.safe_load(dumped) + + # NaN != NaN, for example + if x == x # rubocop:disable Lint/BinaryOperatorWithIdenticalOperands + # assert_equal x, safe_loaded, "should load #{dumped.inspect}" + assert_equal loaded, safe_loaded, "should equal what Marshal.load returns" + end + + assert_equal x.to_s, safe_loaded.to_s, "should have equal to_s" + assert_equal x.inspect, safe_loaded.inspect, "should have equal inspect" + additional_methods.each do |m| + assert_equal loaded.send(m), safe_loaded.send(m), "should have equal #{m}" + end + assert_equal Marshal.dump(loaded), Marshal.dump(safe_loaded), "should Marshal.dump the same" + end +end |