summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Giddins <[email protected]>2023-08-18 13:35:23 -0700
committergit <[email protected]>2023-09-20 02:02:58 +0000
commitd182d83ce929cd322f4a6fd134cd31be950eca77 (patch)
tree18808133999aab348990d35cae6920d0252dfa56
parentc47608494f961d2a8fe24b1a7b7f627b305cf7fe (diff)
[rubygems/rubygems] Add a Marshal.load replacement that walks an AST to safely load permitted classes/symbols
https://github.com/rubygems/rubygems/commit/7e4478fe73
-rw-r--r--lib/rubygems.rb10
-rw-r--r--lib/rubygems/indexer.rb3
-rw-r--r--lib/rubygems/safe_marshal.rb71
-rw-r--r--lib/rubygems/safe_marshal/elements.rb138
-rw-r--r--lib/rubygems/safe_marshal/reader.rb182
-rw-r--r--lib/rubygems/safe_marshal/visitors/to_ruby.rb266
-rw-r--r--lib/rubygems/safe_marshal/visitors/visitor.rb74
-rw-r--r--lib/rubygems/source.rb9
-rw-r--r--lib/rubygems/specification.rb3
-rw-r--r--test/rubygems/test_gem_safe_marshal.rb144
10 files changed, 895 insertions, 5 deletions
diff --git a/lib/rubygems.rb b/lib/rubygems.rb
index cb71657018..557d7aa8eb 100644
--- a/lib/rubygems.rb
+++ b/lib/rubygems.rb
@@ -604,6 +604,16 @@ An Array (#{env.inspect}) was passed in from #{caller[3]}
@yaml_loaded = true
end
+ @safe_marshal_loaded = false
+
+ def self.load_safe_marshal
+ return if @safe_marshal_loaded
+
+ require_relative "rubygems/safe_marshal"
+
+ @safe_marshal_loaded = true
+ end
+
##
# The file name and line number of the caller of the caller of this method.
#
diff --git a/lib/rubygems/indexer.rb b/lib/rubygems/indexer.rb
index c6691517b3..f4c981b9ef 100644
--- a/lib/rubygems/indexer.rb
+++ b/lib/rubygems/indexer.rb
@@ -411,7 +411,8 @@ class Gem::Indexer
# +dest+. For a latest index, does not ensure the new file is minimal.
def update_specs_index(index, source, dest)
- specs_index = Marshal.load Gem.read_binary(source)
+ Gem.load_safe_marshal
+ specs_index = Gem::SafeMarshal.safe_load Gem.read_binary(source)
index.each do |spec|
platform = spec.original_platform
diff --git a/lib/rubygems/safe_marshal.rb b/lib/rubygems/safe_marshal.rb
new file mode 100644
index 0000000000..172b92d5a6
--- /dev/null
+++ b/lib/rubygems/safe_marshal.rb
@@ -0,0 +1,71 @@
+# frozen_string_literal: true
+
+require_relative "safe_marshal/reader"
+require_relative "safe_marshal/visitors/to_ruby"
+
+module Gem
+ ###
+ # This module is used for safely loading Marshal specs from a gem. The
+ # `safe_load` method defined on this module is specifically designed for
+ # loading Gem specifications.
+
+ module SafeMarshal
+ PERMITTED_CLASSES = %w[
+ Time
+ Date
+
+ Gem::Dependency
+ Gem::NameTuple
+ Gem::Platform
+ Gem::Requirement
+ Gem::Specification
+ Gem::Version
+ Gem::Version::Requirement
+
+ YAML::Syck::DefaultKey
+ YAML::PrivateType
+ ].freeze
+ private_constant :PERMITTED_CLASSES
+
+ PERMITTED_SYMBOLS = %w[
+ E
+
+ offset
+ zone
+ nano_num
+ nano_den
+ submicro
+
+ @_zone
+ @cpu
+ @force_ruby_platform
+ @marshal_with_utc_coercion
+ @name
+ @os
+ @platform
+ @prerelease
+ @requirement
+ @taguri
+ @type
+ @type_id
+ @value
+ @version
+ @version_requirement
+ @version_requirements
+
+ development
+ runtime
+ ].freeze
+ private_constant :PERMITTED_SYMBOLS
+
+ def self.safe_load(input)
+ load(input, permitted_classes: PERMITTED_CLASSES, permitted_symbols: PERMITTED_SYMBOLS)
+ end
+
+ def self.load(input, permitted_classes: [::Symbol], permitted_symbols: [])
+ root = Reader.new(StringIO.new(input, "r")).read!
+
+ Visitors::ToRuby.new(permitted_classes: permitted_classes, permitted_symbols: permitted_symbols).visit(root)
+ end
+ end
+end
diff --git a/lib/rubygems/safe_marshal/elements.rb b/lib/rubygems/safe_marshal/elements.rb
new file mode 100644
index 0000000000..70961c40ff
--- /dev/null
+++ b/lib/rubygems/safe_marshal/elements.rb
@@ -0,0 +1,138 @@
+# frozen_string_literal: true
+
+module Gem
+ module SafeMarshal
+ module Elements
+ class Element
+ end
+
+ class Symbol < Element
+ def initialize(name:)
+ @name = name
+ end
+ attr_reader :name
+ end
+
+ class UserDefined < Element
+ def initialize(name:, binary_string:)
+ @name = name
+ @binary_string = binary_string
+ end
+
+ attr_reader :name, :binary_string
+ end
+
+ class UserMarshal < Element
+ def initialize(name:, data:)
+ @name = name
+ @data = data
+ end
+
+ attr_reader :name, :data
+ end
+
+ class String < Element
+ def initialize(str:)
+ @str = str
+ end
+
+ attr_reader :str
+ end
+
+ class Hash < Element
+ def initialize(pairs:)
+ @pairs = pairs
+ end
+
+ attr_reader :pairs
+ end
+
+ class HashWithDefaultValue < Hash
+ def initialize(default:, **kwargs)
+ super(**kwargs)
+ @default = default
+ end
+
+ attr_reader :default
+ end
+
+ class Array < Element
+ def initialize(elements:)
+ @elements = elements
+ end
+
+ attr_reader :elements
+ end
+
+ class Integer < Element
+ def initialize(int:)
+ @int = int
+ end
+
+ attr_reader :int
+ end
+
+ class True < Element
+ def initialize
+ end
+ TRUE = new.freeze
+ end
+
+ class False < Element
+ def initialize
+ end
+
+ FALSE = new.freeze
+ end
+
+ class WithIvars < Element
+ def initialize(object:,ivars:)
+ @object = object
+ @ivars = ivars
+ end
+
+ attr_reader :object, :ivars
+ end
+
+ class Object < Element
+ def initialize(name:)
+ @name = name
+ end
+ attr_reader :name
+ end
+
+ class Nil < Element
+ NIL = new.freeze
+ end
+
+ class ObjectLink < Element
+ def initialize(offset:)
+ @offset = offset
+ end
+ attr_reader :offset
+ end
+
+ class SymbolLink < Element
+ def initialize(offset:)
+ @offset = offset
+ end
+ attr_reader :offset
+ end
+
+ class Float < Element
+ def initialize(string:)
+ @string = string
+ end
+ attr_reader :string
+ end
+
+ class Bignum < Element # rubocop:disable Lint/UnifiedInteger
+ def initialize(sign:, data:)
+ @sign = sign
+ @data = data
+ end
+ attr_reader :sign, :data
+ end
+ end
+ end
+end
diff --git a/lib/rubygems/safe_marshal/reader.rb b/lib/rubygems/safe_marshal/reader.rb
new file mode 100644
index 0000000000..105984ff04
--- /dev/null
+++ b/lib/rubygems/safe_marshal/reader.rb
@@ -0,0 +1,182 @@
+# frozen_string_literal: true
+
+require_relative "elements"
+
+module Gem
+ module SafeMarshal
+ class Reader
+ class UnconsumedBytesError < StandardError
+ end
+
+ def initialize(io)
+ @io = io
+ end
+
+ def read!
+ read_header
+ root = read_element
+ raise UnconsumedBytesError unless @io.eof?
+ root
+ end
+
+ private
+
+ MARSHAL_VERSION = [Marshal::MAJOR_VERSION, Marshal::MINOR_VERSION].map(&:chr).join.freeze
+ private_constant :MARSHAL_VERSION
+
+ def read_header
+ v = @io.read(2)
+ raise "Unsupported marshal version #{v.inspect}, expected #{MARSHAL_VERSION.inspect}" unless v == MARSHAL_VERSION
+ end
+
+ def read_byte
+ @io.getbyte
+ end
+
+ def read_integer
+ b = read_byte
+
+ case b
+ when 0x00
+ 0
+ when 0x01
+ @io.read(1).unpack1("C")
+ when 0x02
+ @io.read(2).unpack1("S<")
+ when 0x03
+ (@io.read(3) + "\0").unpack1("L<")
+ when 0x04
+ @io.read(4).unpack1("L<")
+ when 0xFC
+ @io.read(4).unpack1("L<") | -0x100000000
+ when 0xFD
+ (@io.read(3) + "\0").unpack1("L<") | -0x1000000
+ when 0xFE
+ @io.read(2).unpack1("s<") | -0x10000
+ when 0xFF
+ read_byte | -0x100
+ else
+ signed = (b ^ 128) - 128
+ if b >= 128
+ signed + 5
+ else
+ signed - 5
+ end
+ end
+ end
+
+ def read_element
+ type = read_byte
+ case type
+ when 34 then read_string # ?"
+ when 48 then read_nil # ?0
+ when 58 then read_symbol # ?:
+ when 59 then read_symbol_link # ?;
+ when 64 then read_object_link # ?@
+ when 70 then read_false # ?F
+ when 73 then read_object_with_ivars # ?I
+ when 84 then read_true # ?T
+ when 85 then read_user_marshal # ?U
+ when 91 then read_array # ?[
+ when 102 then read_float # ?f
+ when 105 then Elements::Integer.new int: read_integer # ?i
+ when 108 then read_bignum
+ when 111 then read_object # ?o
+ when 117 then read_user_defined # ?u
+ when 123 then read_hash # ?{
+ when 125 then read_hash_with_default_value # ?}
+ when "e".ord then read_extended_object
+ when "c".ord then read_class
+ when "m".ord then read_module
+ when "M".ord then read_class_or_module
+ when "d".ord then read_data
+ when "/".ord then read_regexp
+ when "S".ord then read_struct
+ when "C".ord then read_user_class
+ else
+ raise "Unsupported marshal type discriminator #{type.chr.inspect} (#{type})"
+ end
+ end
+
+ def read_symbol
+ Elements::Symbol.new name: @io.read(read_integer)
+ end
+
+ def read_string
+ Elements::String.new(str: @io.read(read_integer))
+ end
+
+ def read_true
+ Elements::True::TRUE
+ end
+
+ def read_false
+ Elements::False::FALSE
+ end
+
+ def read_user_defined
+ Elements::UserDefined.new(name: read_element, binary_string: @io.read(read_integer))
+ end
+
+ def read_array
+ Elements::Array.new(elements: Array.new(read_integer) do |_i|
+ read_element
+ end)
+ end
+
+ def read_object_with_ivars
+ Elements::WithIvars.new(object: read_element, ivars:
+ Array.new(read_integer) do
+ [read_element, read_element]
+ end)
+ end
+
+ def read_symbol_link
+ Elements::SymbolLink.new offset: read_integer
+ end
+
+ def read_user_marshal
+ Elements::UserMarshal.new(name: read_element, data: read_element)
+ end
+
+ def read_object_link
+ Elements::ObjectLink.new(offset: read_integer)
+ end
+
+ def read_hash
+ pairs = Array.new(read_integer) do
+ [read_element, read_element]
+ end
+ Elements::Hash.new(pairs: pairs)
+ end
+
+ def read_hash_with_default_value
+ pairs = Array.new(read_integer) do
+ [read_element, read_element]
+ end
+ Elements::HashWithDefaultValue.new(pairs: pairs, default: read_element)
+ end
+
+ def read_object
+ Elements::WithIvars.new(
+ object: Elements::Object.new(name: read_element),
+ ivars: Array.new(read_integer) do
+ [read_element, read_element]
+ end
+ )
+ end
+
+ def read_nil
+ Elements::Nil::NIL
+ end
+
+ def read_float
+ Elements::Float.new string: @io.read(read_integer)
+ end
+
+ def read_bignum
+ Elements::Bignum.new(sign: read_byte, data: @io.read(read_integer * 2))
+ end
+ end
+ end
+end
diff --git a/lib/rubygems/safe_marshal/visitors/to_ruby.rb b/lib/rubygems/safe_marshal/visitors/to_ruby.rb
new file mode 100644
index 0000000000..f81b91fb46
--- /dev/null
+++ b/lib/rubygems/safe_marshal/visitors/to_ruby.rb
@@ -0,0 +1,266 @@
+# frozen_string_literal: true
+
+require_relative "visitor"
+
+module Gem::SafeMarshal
+ module Visitors
+ class ToRuby < Visitor
+ def initialize(permitted_classes:, permitted_symbols:)
+ @permitted_classes = permitted_classes
+ @permitted_symbols = permitted_symbols | permitted_classes | ["E"]
+
+ @objects = []
+ @symbols = []
+ @class_cache = {}
+
+ @stack = ["root"]
+ end
+
+ def inspect # :nodoc:
+ format("#<%s permitted_classes: %p permitted_symbols: %p>", self.class, @permitted_classes, @permitted_symbols)
+ end
+
+ def visit(target)
+ depth = @stack.size
+ super
+ ensure
+ @stack.slice!(depth.pred..)
+ end
+
+ private
+
+ def visit_Gem_SafeMarshal_Elements_Array(a)
+ register_object([]).replace(a.elements.each_with_index.map do |e, i|
+ @stack << "[#{i}]"
+ visit(e)
+ end)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Symbol(s)
+ resolve_symbol(s.name)
+ end
+
+ def map_ivars(ivars)
+ ivars.map.with_index do |(k, v), i|
+ @stack << "ivar #{i}"
+ k = visit(k)
+ @stack << k
+ next k, visit(v)
+ end
+ end
+
+ def visit_Gem_SafeMarshal_Elements_WithIvars(e)
+ idx = 0
+ object_offset = @objects.size
+ @stack << "object"
+ object = visit(e.object)
+ ivars = map_ivars(e.ivars)
+
+ case e.object
+ when Elements::UserDefined
+ if object.class == ::Time
+ offset = zone = nano_num = nano_den = nil
+ ivars.reject! do |k, v|
+ case k
+ when :offset
+ offset = v
+ when :zone
+ zone = v
+ when :nano_num
+ nano_num = v
+ when :nano_den
+ nano_den = v
+ when :submicro
+ else
+ next false
+ end
+ true
+ end
+ object = object.localtime offset if offset
+ if (nano_den || nano_num) && !(nano_den && nano_num)
+ raise FormatError, "Must have all of nano_den, nano_num for Time #{e.pretty_inspect}"
+ elsif nano_den && nano_num
+ nano = Rational(nano_num, nano_den)
+ nsec, subnano = nano.divmod(1)
+ nano = nsec + subnano
+
+ object = Time.at(object.to_r, nano, :nanosecond)
+ end
+ if zone
+ require "time"
+ Time.send(:force_zone!, object, zone, offset)
+ end
+ @objects[object_offset] = object
+ end
+ when Elements::String
+ enc = nil
+
+ ivars.each do |k, v|
+ case k
+ when :E
+ case v
+ when TrueClass
+ enc = "UTF-8"
+ when FalseClass
+ enc = "US-ASCII"
+ end
+ else
+ break
+ end
+ idx += 1
+ end
+
+ object.replace ::String.new(object, encoding: enc)
+ end
+
+ ivars[idx..].each do |k, v|
+ object.instance_variable_set k, v
+ end
+ object
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Hash(o)
+ hash = register_object({})
+
+ o.pairs.each_with_index do |(k, v), i|
+ @stack << i
+ k = visit(k)
+ @stack << k
+ hash[k] = visit(v)
+ end
+
+ hash
+ end
+
+ def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(o)
+ hash = visit_Gem_SafeMarshal_Elements_Hash(o)
+ @stack << :default
+ hash.default = visit(o.default)
+ hash
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Object(o)
+ register_object(resolve_class(o.name).allocate)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_ObjectLink(o)
+ @objects[o.offset]
+ end
+
+ def visit_Gem_SafeMarshal_Elements_SymbolLink(o)
+ @symbols[o.offset]
+ end
+
+ def visit_Gem_SafeMarshal_Elements_UserDefined(o)
+ register_object(resolve_class(o.name).send(:_load, o.binary_string))
+ end
+
+ def visit_Gem_SafeMarshal_Elements_UserMarshal(o)
+ register_object(resolve_class(o.name).allocate).tap do |object|
+ @stack << :data
+ object.marshal_load visit(o.data)
+ end
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Integer(i)
+ i.int
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Nil(_)
+ nil
+ end
+
+ def visit_Gem_SafeMarshal_Elements_True(_)
+ true
+ end
+
+ def visit_Gem_SafeMarshal_Elements_False(_)
+ false
+ end
+
+ def visit_Gem_SafeMarshal_Elements_String(s)
+ register_object(s.str)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Float(f)
+ case f.string
+ when "inf"
+ ::Float::INFINITY
+ when "-inf"
+ -::Float::INFINITY
+ when "nan"
+ ::Float::NAN
+ else
+ f.string.to_f
+ end
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Bignum(b)
+ result = 0
+ b.data.each_byte.with_index do |byte, exp|
+ result += (byte * 2**(exp * 8))
+ end
+
+ case b.sign
+ when 43 # ?+
+ result
+ when 45 # ?-
+ -result
+ else
+ raise FormatError, "Unexpected sign for Bignum #{b.sign.chr.inspect} (#{b.sign})"
+ end
+ end
+
+ def resolve_class(n)
+ @class_cache[n] ||= begin
+ name = nil
+ case n
+ when Elements::Symbol, Elements::SymbolLink
+ @stack << "class name"
+ name = visit(n)
+ else
+ raise FormatError, "Class names must be Symbol or SymbolLink"
+ end
+ to_s = name.to_s
+ raise UnpermittedClassError.new(name: name, stack: @stack.dup) unless @permitted_classes.include?(to_s)
+ begin
+ ::Object.const_get(to_s)
+ rescue NameError
+ raise ArgumentError, "Undefined class #{to_s.inspect}"
+ end
+ end
+ end
+
+ def resolve_symbol(name)
+ raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name)
+ sym = name.to_sym
+ @symbols << sym
+ sym
+ end
+
+ def register_object(o)
+ @objects << o
+ o
+ end
+
+ class UnpermittedSymbolError < StandardError
+ def initialize(symbol:, stack:)
+ @symbol = symbol
+ @stack = stack
+ super "Attempting to load unpermitted symbol #{symbol.inspect} @ #{stack.join "."}"
+ end
+ end
+
+ class UnpermittedClassError < StandardError
+ def initialize(name:, stack:)
+ @name = name
+ @stack = stack
+ super "Attempting to load unpermitted class #{name.inspect} @ #{stack.join "."}"
+ end
+ end
+
+ class FormatError < StandardError
+ end
+ end
+ end
+end
diff --git a/lib/rubygems/safe_marshal/visitors/visitor.rb b/lib/rubygems/safe_marshal/visitors/visitor.rb
new file mode 100644
index 0000000000..c9a079dc0e
--- /dev/null
+++ b/lib/rubygems/safe_marshal/visitors/visitor.rb
@@ -0,0 +1,74 @@
+# frozen_string_literal: true
+
+module Gem::SafeMarshal::Visitors
+ class Visitor
+ def visit(target)
+ send DISPATCH.fetch(target.class), target
+ end
+
+ private
+
+ DISPATCH = Gem::SafeMarshal::Elements.constants.each_with_object({}) do |c, h|
+ next if c == :Element
+
+ klass = Gem::SafeMarshal::Elements.const_get(c)
+ h[klass] = :"visit_#{klass.name.gsub("::", "_")}"
+ h.default = :visit_unknown_element
+ end.compare_by_identity.freeze
+ private_constant :DISPATCH
+
+ def visit_unknown_element(e)
+ raise ArgumentError, "Attempting to visit unknown element #{e.inspect}"
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Array(target)
+ target.elements.each {|e| visit(e) }
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Bignum(target); end
+ def visit_Gem_SafeMarshal_Elements_False(target); end
+ def visit_Gem_SafeMarshal_Elements_Float(target); end
+
+ def visit_Gem_SafeMarshal_Elements_Hash(target)
+ target.pairs.each do |k, v|
+ visit(k)
+ visit(v)
+ end
+ end
+
+ def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(target)
+ visit_Gem_SafeMarshal_Elements_Hash(target)
+ visit(target.default)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_Integer(target); end
+ def visit_Gem_SafeMarshal_Elements_Nil(target); end
+
+ def visit_Gem_SafeMarshal_Elements_Object(target)
+ visit(target.name)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_ObjectLink(target); end
+ def visit_Gem_SafeMarshal_Elements_String(target); end
+ def visit_Gem_SafeMarshal_Elements_Symbol(target); end
+ def visit_Gem_SafeMarshal_Elements_SymbolLink(target); end
+ def visit_Gem_SafeMarshal_Elements_True(target); end
+
+ def visit_Gem_SafeMarshal_Elements_UserDefined(target)
+ visit(target.name)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_UserMarshal(target)
+ visit(target.name)
+ visit(target.data)
+ end
+
+ def visit_Gem_SafeMarshal_Elements_WithIvars(target)
+ visit(target.object)
+ target.ivars.each do |k, v|
+ visit(k)
+ visit(v)
+ end
+ end
+ end
+end
diff --git a/lib/rubygems/source.rb b/lib/rubygems/source.rb
index 8b3a8828d1..7c5b746a43 100644
--- a/lib/rubygems/source.rb
+++ b/lib/rubygems/source.rb
@@ -135,8 +135,9 @@ class Gem::Source
if File.exist? local_spec
spec = Gem.read_binary local_spec
+ Gem.load_safe_marshal
spec = begin
- Marshal.load(spec)
+ Gem::SafeMarshal.safe_load(spec)
rescue StandardError
nil
end
@@ -157,8 +158,9 @@ class Gem::Source
end
end
+ Gem.load_safe_marshal
# TODO: Investigate setting Gem::Specification#loaded_from to a URI
- Marshal.load spec
+ Gem::SafeMarshal.safe_load spec
end
##
@@ -188,8 +190,9 @@ class Gem::Source
spec_dump = fetcher.cache_update_path spec_path, local_file, update_cache?
+ Gem.load_safe_marshal
begin
- Gem::NameTuple.from_list Marshal.load(spec_dump)
+ Gem::NameTuple.from_list Gem::SafeMarshal.safe_load(spec_dump)
rescue ArgumentError
if update_cache? && !retried
FileUtils.rm local_file
diff --git a/lib/rubygems/specification.rb b/lib/rubygems/specification.rb
index 0b9bf143df..8a1cd945a5 100644
--- a/lib/rubygems/specification.rb
+++ b/lib/rubygems/specification.rb
@@ -1300,12 +1300,13 @@ class Gem::Specification < Gem::BasicSpecification
def self._load(str)
Gem.load_yaml
+ Gem.load_safe_marshal
yaml_set = false
retry_count = 0
array = begin
- Marshal.load str
+ Gem::SafeMarshal.safe_load str
rescue ArgumentError => e
# Avoid an infinite retry loop when the argument error has nothing to do
# with the classes not being defined.
diff --git a/test/rubygems/test_gem_safe_marshal.rb b/test/rubygems/test_gem_safe_marshal.rb
new file mode 100644
index 0000000000..5133a63622
--- /dev/null
+++ b/test/rubygems/test_gem_safe_marshal.rb
@@ -0,0 +1,144 @@
+# frozen_string_literal: true
+
+require_relative "helper"
+
+require "date"
+require "rubygems/safe_marshal"
+
+class TestGemSafeMarshal < Gem::TestCase
+ def test_repeated_symbol
+ assert_safe_load_as [:development, :development]
+ end
+
+ def test_repeated_string
+ s = "hello"
+ a = [s]
+ assert_safe_load_as [s, a, s, a]
+ assert_safe_load_as [s, s]
+ end
+
+ def test_recursive_string
+ s = String.new("hello")
+ s.instance_variable_set(:@type, s)
+ assert_safe_load_as s, additional_methods: [:instance_variables]
+ end
+
+ def test_recursive_array
+ a = []
+ a << a
+ assert_safe_load_as a
+ end
+
+ def test_time_loads
+ assert_safe_load_as Time.new
+ end
+
+ def test_time_with_zone_loads
+ assert_safe_load_as Time.now(in: "+04:00")
+ end
+
+ def test_string_with_encoding
+ assert_safe_load_as String.new("abc", encoding: "US-ASCII")
+ assert_safe_load_as String.new("abc", encoding: "UTF-8")
+ end
+
+ def test_string_with_ivar
+ assert_safe_load_as String.new("abc").tap { _1.instance_variable_set :@type, "type" }
+ end
+
+ def test_time_with_ivar
+ assert_safe_load_as Time.new.tap { _1.instance_variable_set :@type, "type" }
+ end
+
+ secs = Time.new(2000, 12, 31, 23, 59, 59).to_i
+ [
+ Time.at(secs, 1, :millisecond),
+ Time.at(secs, 1.1, :millisecond),
+ Time.at(secs, 1.01, :millisecond),
+ Time.at(secs, 1, :microsecond),
+ Time.at(secs, 1.1, :microsecond),
+ Time.at(secs, 1.01, :microsecond),
+ Time.at(secs, 1, :nanosecond),
+ Time.at(secs, 1.1, :nanosecond),
+ Time.at(secs, 1.01, :nanosecond),
+ Time.at(secs, 1.001, :nanosecond),
+ Time.at(secs, 1.00001, :nanosecond),
+ Time.at(secs, 1.00001, :nanosecond).tap {|t| t.instance_variable_set :@type, "type" },
+ ].each_with_index do |t, i|
+ define_method("test_time_#{i} #{t.inspect}") do
+ assert_safe_load_as t, additional_methods: [:ctime, :to_f, :to_r, :to_i, :zone, :subsec, :instance_variables, :to_a]
+ end
+ end
+
+ def test_floats
+ [0.0, Float::INFINITY, Float::NAN, 1.1, 3e7].each do |f|
+ assert_safe_load_as f
+ assert_safe_load_as(-f)
+ end
+ end
+
+ def test_hash_with_ivar
+ assert_safe_load_as({ runtime: :development }.tap { _1.instance_variable_set :@type, "null" })
+ end
+
+ def test_hash_with_default_value
+ assert_safe_load_as Hash.new([])
+ end
+
+ def test_frozen_object
+ assert_safe_load_as Gem::Version.new("1.abc").freeze
+ end
+
+ def test_date
+ assert_safe_load_as Date.new
+ end
+
+ [
+ 0, 1, 2, 3, 4, 5, 6, 122, 123, 124, 127, 128, 255, 256, 257,
+ 2**16, 2**16 - 1, 2**20 - 1,
+ 2**28, 2**28 - 1,
+ 2**32, 2**32 - 1,
+ 2**63, 2**63 - 1
+ ].
+ each do |i|
+ define_method("test_int_ #{i}") do
+ assert_safe_load_as i
+ assert_safe_load_as(-i)
+ assert_safe_load_as(i + 1)
+ assert_safe_load_as(i - 1)
+ end
+ end
+
+ def test_gem_spec_disallowed_symbol
+ e = assert_raise(Gem::SafeMarshal::Visitors::ToRuby::UnpermittedSymbolError) do
+ spec = Gem::Specification.new do |s|
+ s.name = "hi"
+ s.version = "1.2.3"
+
+ s.dependencies << Gem::Dependency.new("rspec", Gem::Requirement.new([">= 1.2.3"]), :runtime).tap { _1.instance_variable_set(:@name, :rspec) }
+ end
+ Gem::SafeMarshal.safe_load(Marshal.dump(spec))
+ end
+
+ assert_equal e.message, "Attempting to load unpermitted symbol \"rspec\" @ root.[9].[0].@name"
+ end
+
+ def assert_safe_load_as(x, additional_methods: [])
+ dumped = Marshal.dump(x)
+ loaded = Marshal.load(dumped)
+ safe_loaded = Gem::SafeMarshal.safe_load(dumped)
+
+ # NaN != NaN, for example
+ if x == x # rubocop:disable Lint/BinaryOperatorWithIdenticalOperands
+ # assert_equal x, safe_loaded, "should load #{dumped.inspect}"
+ assert_equal loaded, safe_loaded, "should equal what Marshal.load returns"
+ end
+
+ assert_equal x.to_s, safe_loaded.to_s, "should have equal to_s"
+ assert_equal x.inspect, safe_loaded.inspect, "should have equal inspect"
+ additional_methods.each do |m|
+ assert_equal loaded.send(m), safe_loaded.send(m), "should have equal #{m}"
+ end
+ assert_equal Marshal.dump(loaded), Marshal.dump(safe_loaded), "should Marshal.dump the same"
+ end
+end