diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 8f22a761e790..6aab7d54d59e 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -8,7 +8,7 @@ import parsing.Parser import Phases.Phase import transform.* import backend.jvm.{CollectSuperCalls, GenBCode} -import localopt.StringInterpolatorOpt +import localopt.{StringInterpolatorOpt, DropForMap} /** The central class of the dotc compiler. The job of a compiler is to create * runs, which process given `phases` in a given `rootContext`. @@ -68,7 +68,8 @@ class Compiler { new InlineVals, // Check right hand-sides of an `inline val`s new ExpandSAMs, // Expand single abstract method closures to anonymous classes new ElimRepeated, // Rewrite vararg parameters and arguments - new RefChecks) :: // Various checks mostly related to abstract members and overriding + new RefChecks, // Various checks mostly related to abstract members and overriding + new DropForMap) :: // Drop unused trailing map calls in for comprehensions List(new semanticdb.ExtractSemanticDB.AppendDiagnostics) :: // Attach warnings to extracted SemanticDB and write to .semanticdb file List(new init.Checker) :: // Check initialization of objects List(new ProtectedAccessors, // Add accessors for protected members diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index ec65224ac93d..2d0d6d25b190 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -64,6 +64,11 @@ object desugar { */ val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey() + /** An attachment key to indicate that an Apply is created as a last `map` + * scall in a for-comprehension. + */ + val TrailingForMap: Property.Key[Unit] = Property.StickyKey() + /** What static check should be applied to a Match? */ enum MatchCheck { case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom @@ -1966,14 +1971,8 @@ object desugar { * * 3. * - * for (P <- G) yield P ==> G - * - * If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter. - * * for (P <- G) yield E ==> G.map (P => E) * - * Otherwise - * * 4. * * for (P_1 <- G_1; P_2 <- G_2; ...) ... @@ -2146,14 +2145,20 @@ object desugar { case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) case _ => false + def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit = + if betterForsEnabled + && selectName == mapName + && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil))) + then + aply.putAttachment(TrailingForMap, ()) + enums match { case Nil if betterForsEnabled => body case (gen: GenFrom) :: Nil => - if betterForsEnabled - && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && deepEquals(gen.pat, body) - then gen.expr // avoid a redundant map with identity - else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + markTrailingMap(aply, gen, mapName) + aply case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) @@ -2164,7 +2169,9 @@ object desugar { val selectName = if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName else mapName - Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + markTrailingMap(aply, gen, selectName) + aply case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) val pats = valeqs map { case GenAlias(pat, _) => pat } diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala new file mode 100644 index 000000000000..f7594f041204 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala @@ -0,0 +1,54 @@ +package dotty.tools.dotc +package transform.localopt + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Decorators.* +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.transform.MegaPhase.MiniPhase +import dotty.tools.dotc.ast.desugar + +/** Drop unused trailing map calls in for comprehensions. + * We can drop the map call if: + * - it won't change the type of the expression, and + * - the function is an identity function or a const function to unit. + * + * The latter condition is checked in [[Desugar.scala#makeFor]] + */ +class DropForMap extends MiniPhase: + import DropForMap.* + + override def phaseName: String = DropForMap.name + + override def description: String = DropForMap.description + + override def transformApply(tree: Apply)(using Context): Tree = + if !tree.hasAttachment(desugar.TrailingForMap) then tree + else tree match + case aply @ Apply(MapCall(f), List(Lambda(List(param), body))) + if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change + f // drop the map call + case _ => + tree.removeAttachment(desugar.TrailingForMap) + tree + + private object Lambda: + def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = + tree match + case Block(List(defdef: DefDef), Closure(Nil, ref, _)) + if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) => + Some((defdef.termParamss.flatten, defdef.rhs)) + case _ => None + + private object MapCall: + def unapply(tree: Tree)(using Context): Option[Tree] = tree match + case Select(f, nme.map) => Some(f) + case Apply(fn, _) => unapply(fn) + case TypeApply(fn, _) => unapply(fn) + case _ => None + +object DropForMap: + val name: String = "dropForMap" + val description: String = "Drop unused trailing map calls in for comprehensions" diff --git a/docs/_docs/reference/experimental/better-fors.md b/docs/_docs/reference/experimental/better-fors.md index a4c42c9fb380..4f910259aab2 100644 --- a/docs/_docs/reference/experimental/better-fors.md +++ b/docs/_docs/reference/experimental/better-fors.md @@ -60,7 +60,7 @@ Additionally this extension changes the way `for`-comprehensions are desugared. This change makes the desugaring more intuitive and avoids unnecessary `map` calls, when an alias is not followed by a guard. 2. **Avoiding Redundant `map` Calls**: - When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. but th eequality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. + When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. But the equality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. There is also a special case for dropping the `map`, if its body is a constant function, that returns `()` (`Unit` constant). **Current Desugaring**: ```scala for { diff --git a/tests/pos/better-fors-i21804.scala b/tests/pos/better-fors-i21804.scala new file mode 100644 index 000000000000..7c8c753bf7c3 --- /dev/null +++ b/tests/pos/better-fors-i21804.scala @@ -0,0 +1,13 @@ +import scala.language.experimental.betterFors + +case class Container[A](val value: A) { + def map[B](f: A => B): Container[B] = Container(f(value)) +} + +sealed trait Animal +case class Dog() extends Animal + +def opOnDog(dog: Container[Dog]): Container[Animal] = + for + v <- dog + yield v diff --git a/tests/run/better-fors-map-elim.check b/tests/run/better-fors-map-elim.check new file mode 100644 index 000000000000..0ef3447a47c4 --- /dev/null +++ b/tests/run/better-fors-map-elim.check @@ -0,0 +1,4 @@ +MySome(()) +MySome(2) +MySome((2,3)) +MySome((2,(3,4))) diff --git a/tests/run/better-fors-map-elim.scala b/tests/run/better-fors-map-elim.scala new file mode 100644 index 000000000000..653984bc8e28 --- /dev/null +++ b/tests/run/better-fors-map-elim.scala @@ -0,0 +1,64 @@ +import scala.language.experimental.betterFors + +class myOptionModule(doOnMap: => Unit) { + sealed trait MyOption[+A] { + def map[B](f: A => B): MyOption[B] = this match { + case MySome(x) => { + doOnMap + MySome(f(x)) + } + case MyNone => MyNone + } + def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match { + case MySome(x) => f(x) + case MyNone => MyNone + } + } + case class MySome[A](x: A) extends MyOption[A] + case object MyNone extends MyOption[Nothing] + object MyOption { + def apply[A](x: A): MyOption[A] = MySome(x) + } +} + +object Test extends App { + + val myOption = new myOptionModule(println("map called")) + + import myOption.* + + def portablePrintMyOption(opt: MyOption[Any]): Unit = + if opt == MySome(()) then + println("MySome(())") + else + println(opt) + + val z = for { + a <- MyOption(1) + b <- MyOption(()) + } yield () + + portablePrintMyOption(z) + + val z2 = for { + a <- MyOption(1) + b <- MyOption(2) + } yield b + + portablePrintMyOption(z2) + + val z3 = for { + a <- MyOption(1) + (b, c) <- MyOption((2, 3)) + } yield (b, c) + + portablePrintMyOption(z3) + + val z4 = for { + a <- MyOption(1) + (b, (c, d)) <- MyOption((2, (3, 4))) + } yield (b, (c, d)) + + portablePrintMyOption(z4) + +}