summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Newton <[email protected]>2024-10-11 14:43:23 -0400
committergit <[email protected]>2024-10-11 19:34:57 +0000
commit5f62522d5b8bd162ddf657680b8532eadeaae21f (patch)
treefa8f8e757ac296dc658cc9dd809f51ddb4fd6401
parent8aeb60aec88dd68fdfbaa75ca06e65188233ccbf (diff)
[ruby/prism] Prism::StringQuery
Introduce StringQuery to provide methods to access some metadata about the Ruby lexer. https://github.com/ruby/prism/commit/d3f55b67b9
-rw-r--r--lib/prism.rb9
-rw-r--r--lib/prism/ffi.rb39
-rw-r--r--lib/prism/prism.gemspec3
-rw-r--r--lib/prism/string_query.rb30
-rw-r--r--prism/extension.c67
-rw-r--r--prism/prism.c163
-rw-r--r--prism/prism.h47
-rw-r--r--test/prism/ruby/string_query_test.rb60
8 files changed, 414 insertions, 4 deletions
diff --git a/lib/prism.rb b/lib/prism.rb
index 66a64e7fd0..50b14a5486 100644
--- a/lib/prism.rb
+++ b/lib/prism.rb
@@ -25,6 +25,7 @@ module Prism
autoload :Pattern, "prism/pattern"
autoload :Reflection, "prism/reflection"
autoload :Serialize, "prism/serialize"
+ autoload :StringQuery, "prism/string_query"
autoload :Translation, "prism/translation"
autoload :Visitor, "prism/visitor"
@@ -75,13 +76,13 @@ require_relative "prism/parse_result"
# it's going to require the built library. Otherwise, it's going to require a
# module that uses FFI to call into the library.
if RUBY_ENGINE == "ruby" and !ENV["PRISM_FFI_BACKEND"]
- require "prism/prism"
-
# The C extension is the default backend on CRuby.
Prism::BACKEND = :CEXT
-else
- require_relative "prism/ffi"
+ require "prism/prism"
+else
# The FFI backend is used on other Ruby implementations.
Prism::BACKEND = :FFI
+
+ require_relative "prism/ffi"
end
diff --git a/lib/prism/ffi.rb b/lib/prism/ffi.rb
index 0520f7cdd2..5caae440f4 100644
--- a/lib/prism/ffi.rb
+++ b/lib/prism/ffi.rb
@@ -73,6 +73,7 @@ module Prism
callback :pm_parse_stream_fgets_t, [:pointer, :int, :pointer], :pointer
enum :pm_string_init_result_t, %i[PM_STRING_INIT_SUCCESS PM_STRING_INIT_ERROR_GENERIC PM_STRING_INIT_ERROR_DIRECTORY]
+ enum :pm_string_query_t, [:PM_STRING_QUERY_ERROR, -1, :PM_STRING_QUERY_FALSE, :PM_STRING_QUERY_TRUE]
load_exported_functions_from(
"prism.h",
@@ -83,6 +84,9 @@ module Prism
"pm_serialize_lex",
"pm_serialize_parse_lex",
"pm_parse_success_p",
+ "pm_string_query_local",
+ "pm_string_query_constant",
+ "pm_string_query_method_name",
[:pm_parse_stream_fgets_t]
)
@@ -492,4 +496,39 @@ module Prism
values.pack(template)
end
end
+
+ # Here we are going to patch StringQuery to put in the class-level methods so
+ # that it can maintain a consistent interface
+ class StringQuery
+ class << self
+ # Mirrors the C extension's StringQuery::local? method.
+ def local?(string)
+ query(LibRubyParser.pm_string_query_local(string, string.bytesize, string.encoding.name))
+ end
+
+ # Mirrors the C extension's StringQuery::constant? method.
+ def constant?(string)
+ query(LibRubyParser.pm_string_query_constant(string, string.bytesize, string.encoding.name))
+ end
+
+ # Mirrors the C extension's StringQuery::method_name? method.
+ def method_name?(string)
+ query(LibRubyParser.pm_string_query_method_name(string, string.bytesize, string.encoding.name))
+ end
+
+ private
+
+ # Parse the enum result and return an appropriate boolean.
+ def query(result)
+ case result
+ when :PM_STRING_QUERY_ERROR
+ raise ArgumentError, "Invalid or non ascii-compatible encoding"
+ when :PM_STRING_QUERY_FALSE
+ false
+ when :PM_STRING_QUERY_TRUE
+ true
+ end
+ end
+ end
+ end
end
diff --git a/lib/prism/prism.gemspec b/lib/prism/prism.gemspec
index 37aa979576..6123b71fc8 100644
--- a/lib/prism/prism.gemspec
+++ b/lib/prism/prism.gemspec
@@ -89,6 +89,7 @@ Gem::Specification.new do |spec|
"lib/prism/polyfill/unpack1.rb",
"lib/prism/reflection.rb",
"lib/prism/serialize.rb",
+ "lib/prism/string_query.rb",
"lib/prism/translation.rb",
"lib/prism/translation/parser.rb",
"lib/prism/translation/parser33.rb",
@@ -109,6 +110,7 @@ Gem::Specification.new do |spec|
"rbi/prism/node.rbi",
"rbi/prism/parse_result.rbi",
"rbi/prism/reflection.rbi",
+ "rbi/prism/string_query.rbi",
"rbi/prism/translation/parser.rbi",
"rbi/prism/translation/parser33.rbi",
"rbi/prism/translation/parser34.rbi",
@@ -129,6 +131,7 @@ Gem::Specification.new do |spec|
"sig/prism/pattern.rbs",
"sig/prism/reflection.rbs",
"sig/prism/serialize.rbs",
+ "sig/prism/string_query.rbs",
"sig/prism/visitor.rbs",
"src/diagnostic.c",
"src/encoding.c",
diff --git a/lib/prism/string_query.rb b/lib/prism/string_query.rb
new file mode 100644
index 0000000000..9011051d2b
--- /dev/null
+++ b/lib/prism/string_query.rb
@@ -0,0 +1,30 @@
+# frozen_string_literal: true
+
+module Prism
+ # Query methods that allow categorizing strings based on their context for
+ # where they could be valid in a Ruby syntax tree.
+ #
+ # Every predicate below forwards the wrapped string to the class-level method
+ # of the same name on StringQuery. Those class-level methods are not defined
+ # in this file; they are supplied by whichever backend is loaded (the C
+ # extension or the FFI bindings), both of which raise ArgumentError with the
+ # message "Invalid or non ascii-compatible encoding" when the native query
+ # reports an error.
+ class StringQuery
+ # The string that this query is wrapping.
+ attr_reader :string
+
+ # Initialize a new query with the given string. The string is stored as-is;
+ # no validation happens until one of the predicates is called.
+ def initialize(string)
+ @string = string
+ end
+
+ # Whether or not this string is a valid local variable name. Delegates to
+ # StringQuery.local?.
+ def local?
+ StringQuery.local?(string)
+ end
+
+ # Whether or not this string is a valid constant name. Delegates to
+ # StringQuery.constant?.
+ def constant?
+ StringQuery.constant?(string)
+ end
+
+ # Whether or not this string is a valid method name. Delegates to
+ # StringQuery.method_name?.
+ def method_name?
+ StringQuery.method_name?(string)
+ end
+ end
+end
diff --git a/prism/extension.c b/prism/extension.c
index 93fa7b0989..47603fd9b4 100644
--- a/prism/extension.c
+++ b/prism/extension.c
@@ -23,6 +23,7 @@ VALUE rb_cPrismResult;
VALUE rb_cPrismParseResult;
VALUE rb_cPrismLexResult;
VALUE rb_cPrismParseLexResult;
+VALUE rb_cPrismStringQuery;
VALUE rb_cPrismDebugEncoding;
@@ -1134,6 +1135,67 @@ parse_file_failure_p(int argc, VALUE *argv, VALUE self) {
}
/******************************************************************************/
+/* String query methods */
+/******************************************************************************/
+
+/**
+ * Process the result of a call to a string query method and return an
+ * appropriate value.
+ *
+ * Raises ArgumentError when the result is PM_STRING_QUERY_ERROR. Otherwise
+ * returns Qtrue for PM_STRING_QUERY_TRUE and Qfalse for PM_STRING_QUERY_FALSE.
+ */
+static VALUE
+string_query(pm_string_query_t result) {
+    switch (result) {
+        case PM_STRING_QUERY_ERROR:
+            rb_raise(rb_eArgError, "Invalid or non ascii-compatible encoding");
+            // rb_raise does not return; this only keeps the arm terminated.
+            return Qfalse;
+        case PM_STRING_QUERY_FALSE:
+            return Qfalse;
+        case PM_STRING_QUERY_TRUE:
+            return Qtrue;
+    }
+
+    // The switch covers every enumerator, but a value outside the enum would
+    // otherwise let control flow off the end of a non-void function. Treat
+    // any such value as a negative answer.
+    return Qfalse;
+}
+
+/**
+ * call-seq:
+ *   Prism::StringQuery::local?(string) -> bool
+ *
+ * Returns true if the string constitutes a valid local variable name. Note that
+ * this means the names that can be set through Binding#local_variable_set, not
+ * necessarily the ones that can be set through a local variable assignment.
+ */
+static VALUE
+string_query_local_p(VALUE self, VALUE string) {
+    // Validate the argument first, then gather what the query needs.
+    const uint8_t *bytes = (const uint8_t *) check_string(string);
+    const size_t length = (size_t) RSTRING_LEN(string);
+    const char *encoding_name = rb_enc_get(string)->name;
+
+    pm_string_query_t outcome = pm_string_query_local(bytes, length, encoding_name);
+    return string_query(outcome);
+}
+
+/**
+ * call-seq:
+ *   Prism::StringQuery::constant?(string) -> bool
+ *
+ * Returns true if the string constitutes a valid constant name. Note that this
+ * means the names that can be set through Module#const_set, not necessarily the
+ * ones that can be set through a constant assignment.
+ */
+static VALUE
+string_query_constant_p(VALUE self, VALUE string) {
+    // Validate the argument first, then gather what the query needs.
+    const uint8_t *bytes = (const uint8_t *) check_string(string);
+    const size_t length = (size_t) RSTRING_LEN(string);
+    const char *encoding_name = rb_enc_get(string)->name;
+
+    pm_string_query_t outcome = pm_string_query_constant(bytes, length, encoding_name);
+    return string_query(outcome);
+}
+
+/**
+ * call-seq:
+ *   Prism::StringQuery::method_name?(string) -> bool
+ *
+ * Returns true if the string constitutes a valid method name.
+ */
+static VALUE
+string_query_method_name_p(VALUE self, VALUE string) {
+    // Validate the argument first, then gather what the query needs.
+    const uint8_t *bytes = (const uint8_t *) check_string(string);
+    const size_t length = (size_t) RSTRING_LEN(string);
+    const char *encoding_name = rb_enc_get(string)->name;
+
+    pm_string_query_t outcome = pm_string_query_method_name(bytes, length, encoding_name);
+    return string_query(outcome);
+}
+
+/******************************************************************************/
/* Initialization of the extension */
/******************************************************************************/
@@ -1170,6 +1232,7 @@ Init_prism(void) {
rb_cPrismParseResult = rb_define_class_under(rb_cPrism, "ParseResult", rb_cPrismResult);
rb_cPrismLexResult = rb_define_class_under(rb_cPrism, "LexResult", rb_cPrismResult);
rb_cPrismParseLexResult = rb_define_class_under(rb_cPrism, "ParseLexResult", rb_cPrismResult);
+ rb_cPrismStringQuery = rb_define_class_under(rb_cPrism, "StringQuery", rb_cObject);
// Intern all of the IDs eagerly that we support so that we don't have to do
// it every time we parse.
@@ -1211,6 +1274,10 @@ Init_prism(void) {
rb_define_singleton_method(rb_cPrism, "dump_file", dump_file, -1);
#endif
+ rb_define_singleton_method(rb_cPrismStringQuery, "local?", string_query_local_p, 1);
+ rb_define_singleton_method(rb_cPrismStringQuery, "constant?", string_query_constant_p, 1);
+ rb_define_singleton_method(rb_cPrismStringQuery, "method_name?", string_query_method_name_p, 1);
+
// Next, initialize the other APIs.
Init_prism_api_node();
Init_prism_pack();
diff --git a/prism/prism.c b/prism/prism.c
index 00485b68ad..07d188dcc8 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -22642,3 +22642,166 @@ pm_serialize_parse_comments(pm_buffer_t *buffer, const uint8_t *source, size_t s
}
#endif
+
+/******************************************************************************/
+/* Slice queries for the Ruby API */
+/******************************************************************************/
+
+/** The classification that pm_slice_type assigns to a slice. */
+typedef enum {
+    /** The encoding name could not be resolved. */
+    PM_SLICE_TYPE_ERROR = -1,
+
+    /** The slice fits none of the categories below. */
+    PM_SLICE_TYPE_NONE = 0,
+
+    /** The slice is a valid local variable name. */
+    PM_SLICE_TYPE_LOCAL = 1,
+
+    /** The slice is a valid constant name. */
+    PM_SLICE_TYPE_CONSTANT = 2,
+
+    /** The slice is a valid method name. */
+    PM_SLICE_TYPE_METHOD_NAME = 3
+} pm_slice_type_t;
+
+/**
+ * Classify the slice as a local variable name, a constant name, a method name
+ * (an identifier followed by a single !, ?, or = suffix), or none of these.
+ *
+ * Returns PM_SLICE_TYPE_ERROR when the encoding name cannot be resolved and
+ * PM_SLICE_TYPE_NONE for an empty slice or one that is not an identifier.
+ */
+pm_slice_type_t
+pm_slice_type(const uint8_t *source, size_t length, const char *encoding_name) {
+    // first, get the right encoding object
+    const pm_encoding_t *encoding = pm_encoding_find((const uint8_t *) encoding_name, (const uint8_t *) (encoding_name + strlen(encoding_name)));
+    if (encoding == NULL) return PM_SLICE_TYPE_ERROR;
+
+    // check that there is at least one character
+    if (length == 0) return PM_SLICE_TYPE_NONE;
+
+    size_t width;
+    if ((width = encoding->alpha_char(source, (ptrdiff_t) length)) != 0) {
+        // valid because alphabetical
+    } else if (*source == '_') {
+        // valid because underscore
+        width = 1;
+    } else if ((*source >= 0x80) && ((width = encoding->char_width(source, (ptrdiff_t) length)) > 0)) {
+        // valid because multibyte
+    } else {
+        // invalid because no match
+        return PM_SLICE_TYPE_NONE;
+    }
+
+    // determine the type of the slice based on the first character
+    const uint8_t *end = source + length;
+    pm_slice_type_t result = encoding->isupper_char(source, end - source) ? PM_SLICE_TYPE_CONSTANT : PM_SLICE_TYPE_LOCAL;
+
+    // next, iterate through all of the bytes of the string to ensure that they
+    // are all valid identifier characters
+    source += width;
+
+    while (source < end) {
+        if ((width = encoding->alnum_char(source, end - source)) != 0) {
+            // valid because alphanumeric
+            source += width;
+        } else if (*source == '_') {
+            // valid because underscore
+            source++;
+        } else if ((*source >= 0x80) && ((width = encoding->char_width(source, end - source)) > 0)) {
+            // valid because multibyte
+            source += width;
+        } else {
+            // invalid because no match
+            break;
+        }
+    }
+
+    // accept a single !, ?, or = at the end of the slice as a method name. The
+    // bounds check is required: when the loop above consumed the whole slice,
+    // source == end and dereferencing it would read one byte past the buffer
+    // described by (source, length).
+    if ((source < end) && (*source == '!' || *source == '?' || *source == '=')) {
+        source++;
+        result = PM_SLICE_TYPE_METHOD_NAME;
+    }
+
+    // valid if we are at the end of the slice
+    return source == end ? result : PM_SLICE_TYPE_NONE;
+}
+
+/**
+ * Report whether the slice is a valid local variable name.
+ */
+PRISM_EXPORTED_FUNCTION pm_string_query_t
+pm_string_query_local(const uint8_t *source, size_t length, const char *encoding_name) {
+    const pm_slice_type_t type = pm_slice_type(source, length, encoding_name);
+
+    // an unresolvable encoding is reported as an error, not as "false"
+    if (type == PM_SLICE_TYPE_ERROR) return PM_STRING_QUERY_ERROR;
+
+    // of the remaining categories, only a local answers the query positively
+    return (type == PM_SLICE_TYPE_LOCAL) ? PM_STRING_QUERY_TRUE : PM_STRING_QUERY_FALSE;
+}
+
+/**
+ * Report whether the slice is a valid constant name.
+ */
+PRISM_EXPORTED_FUNCTION pm_string_query_t
+pm_string_query_constant(const uint8_t *source, size_t length, const char *encoding_name) {
+    const pm_slice_type_t type = pm_slice_type(source, length, encoding_name);
+
+    // an unresolvable encoding is reported as an error, not as "false"
+    if (type == PM_SLICE_TYPE_ERROR) return PM_STRING_QUERY_ERROR;
+
+    // of the remaining categories, only a constant answers the query positively
+    return (type == PM_SLICE_TYPE_CONSTANT) ? PM_STRING_QUERY_TRUE : PM_STRING_QUERY_FALSE;
+}
+
+/**
+ * Check that the slice is a valid method name.
+ *
+ * The slice is first classified with pm_slice_type. Identifier-like slices are
+ * answered from that classification alone; anything classified as
+ * PM_SLICE_TYPE_NONE is then compared against a fixed list of one-, two-, and
+ * three-byte operator-style names.
+ *
+ * Returns PM_STRING_QUERY_ERROR when pm_slice_type reports an error.
+ */
+PRISM_EXPORTED_FUNCTION pm_string_query_t
+pm_string_query_method_name(const uint8_t *source, size_t length, const char *encoding_name) {
+ // B maps a C truth value onto the query result enum. C1/C2/C3 compare the
+ // first one, two, or three bytes of source against a literal; they are only
+ // used under the matching `case` of switch (length) below, so the bytes they
+ // read are within the slice.
+#define B(p) ((p) ? PM_STRING_QUERY_TRUE : PM_STRING_QUERY_FALSE)
+#define C1(c) (*source == c)
+#define C2(s) (memcmp(source, s, 2) == 0)
+#define C3(s) (memcmp(source, s, 3) == 0)
+
+ switch (pm_slice_type(source, length, encoding_name)) {
+ case PM_SLICE_TYPE_ERROR:
+ return PM_STRING_QUERY_ERROR;
+ case PM_SLICE_TYPE_NONE:
+ // not an identifier; fall out of this switch to the operator checks
+ break;
+ case PM_SLICE_TYPE_LOCAL:
+ // numbered parameters are not valid method names
+ // (the expression is false only for a two-byte slice made of '_' followed
+ // by a character other than '0' that pm_char_is_decimal_digit accepts)
+ return B((length != 2) || (source[0] != '_') || (source[1] == '0') || !pm_char_is_decimal_digit(source[1]));
+ case PM_SLICE_TYPE_CONSTANT:
+ // all constants are valid method names
+ case PM_SLICE_TYPE_METHOD_NAME:
+ // all method names are valid method names
+ return PM_STRING_QUERY_TRUE;
+ }
+
+ // Only PM_SLICE_TYPE_NONE reaches this point. An empty slice also arrives
+ // here and is rejected by the default arm.
+ switch (length) {
+ case 1:
+ return B(C1('&') || C1('`') || C1('!') || C1('^') || C1('>') || C1('<') || C1('-') || C1('%') || C1('|') || C1('+') || C1('/') || C1('*') || C1('~'));
+ case 2:
+ return B(C2("!=") || C2("!~") || C2("[]") || C2("==") || C2("=~") || C2(">=") || C2(">>") || C2("<=") || C2("<<") || C2("**"));
+ case 3:
+ return B(C3("===") || C3("<=>") || C3("[]="));
+ default:
+ return PM_STRING_QUERY_FALSE;
+ }
+
+ // The helper macros are local to this function; remove them again.
+#undef B
+#undef C1
+#undef C2
+#undef C3
+}
diff --git a/prism/prism.h b/prism/prism.h
index 755c38fca2..6f7b850a31 100644
--- a/prism/prism.h
+++ b/prism/prism.h
@@ -235,6 +235,53 @@ PRISM_EXPORTED_FUNCTION void pm_dump_json(pm_buffer_t *buffer, const pm_parser_t
#endif
/**
+ * Represents the results of a slice query.
+ */
+typedef enum {
+    /** The encoding given to the slice query was invalid. */
+    PM_STRING_QUERY_ERROR = -1,
+
+    /** The slice query answered false. */
+    PM_STRING_QUERY_FALSE = 0,
+
+    /** The slice query answered true. */
+    PM_STRING_QUERY_TRUE = 1
+} pm_string_query_t;
+
+/**
+ * Check that the slice is a valid local variable name.
+ *
+ * @param source The source to check.
+ * @param length The length of the source.
+ * @param encoding_name The name of the encoding of the source.
+ * @return PM_STRING_QUERY_TRUE if the query is true, PM_STRING_QUERY_FALSE if
+ * the query is false, and PM_STRING_QUERY_ERROR if the encoding was invalid.
+ */
+PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_local(const uint8_t *source, size_t length, const char *encoding_name);
+
+/**
+ * Check that the slice is a valid constant name.
+ *
+ * @param source The source to check.
+ * @param length The length of the source.
+ * @param encoding_name The name of the encoding of the source.
+ * @return PM_STRING_QUERY_TRUE if the query is true, PM_STRING_QUERY_FALSE if
+ * the query is false, and PM_STRING_QUERY_ERROR if the encoding was invalid.
+ */
+PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_constant(const uint8_t *source, size_t length, const char *encoding_name);
+
+/**
+ * Check that the slice is a valid method name.
+ *
+ * @param source The source to check.
+ * @param length The length of the source.
+ * @param encoding_name The name of the encoding of the source.
+ * @return PM_STRING_QUERY_TRUE if the query is true, PM_STRING_QUERY_FALSE if
+ * the query is false, and PM_STRING_QUERY_ERROR if the encoding was invalid.
+ */
+PRISM_EXPORTED_FUNCTION pm_string_query_t pm_string_query_method_name(const uint8_t *source, size_t length, const char *encoding_name);
+
+/**
* @mainpage
*
* Prism is a parser for the Ruby programming language. It is designed to be
diff --git a/test/prism/ruby/string_query_test.rb b/test/prism/ruby/string_query_test.rb
new file mode 100644
index 0000000000..aa50c10ff3
--- /dev/null
+++ b/test/prism/ruby/string_query_test.rb
@@ -0,0 +1,60 @@
+# frozen_string_literal: true
+
+require_relative "../test_helper"
+
+module Prism
+ class StringQueryTest < TestCase
+ def test_local?
+ assert_predicate StringQuery.new("a"), :local?
+ assert_predicate StringQuery.new("a1"), :local?
+ assert_predicate StringQuery.new("self"), :local?
+
+ assert_predicate StringQuery.new("_a"), :local?
+ assert_predicate StringQuery.new("_1"), :local?
+
+ assert_predicate StringQuery.new("😀"), :local?
+ assert_predicate StringQuery.new("ã‚¢".encode("Windows-31J")), :local?
+
+ refute_predicate StringQuery.new("1"), :local?
+ refute_predicate StringQuery.new("A"), :local?
+ end
+
+ def test_constant?
+ assert_predicate StringQuery.new("A"), :constant?
+ assert_predicate StringQuery.new("A1"), :constant?
+ assert_predicate StringQuery.new("A_B"), :constant?
+ assert_predicate StringQuery.new("BEGIN"), :constant?
+
+ assert_predicate StringQuery.new("À"), :constant?
+ assert_predicate StringQuery.new("A".encode("US-ASCII")), :constant?
+
+ refute_predicate StringQuery.new("a"), :constant?
+ refute_predicate StringQuery.new("1"), :constant?
+ end
+
+ def test_method_name?
+ assert_predicate StringQuery.new("a"), :method_name?
+ assert_predicate StringQuery.new("A"), :method_name?
+ assert_predicate StringQuery.new("__FILE__"), :method_name?
+
+ assert_predicate StringQuery.new("a?"), :method_name?
+ assert_predicate StringQuery.new("a!"), :method_name?
+ assert_predicate StringQuery.new("a="), :method_name?
+
+ assert_predicate StringQuery.new("+"), :method_name?
+ assert_predicate StringQuery.new("<<"), :method_name?
+ assert_predicate StringQuery.new("==="), :method_name?
+
+ assert_predicate StringQuery.new("_0"), :method_name?
+
+ refute_predicate StringQuery.new("1"), :method_name?
+ refute_predicate StringQuery.new("_1"), :method_name?
+ end
+
+ def test_invalid_encoding
+ assert_raise ArgumentError do
+ StringQuery.new("A".encode("UTF-16LE")).local?
+ end
+ end
+ end
+end