Skip to content

Under betterFors don't drop the trailing map if it would result in a different type (also drop _ => ()) #22619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import parsing.Parser
import Phases.Phase
import transform.*
import backend.jvm.{CollectSuperCalls, GenBCode}
import localopt.StringInterpolatorOpt
import localopt.{StringInterpolatorOpt, DropForMap}

/** The central class of the dotc compiler. The job of a compiler is to create
* runs, which process given `phases` in a given `rootContext`.
Expand Down Expand Up @@ -68,7 +68,8 @@ class Compiler {
new InlineVals, // Check right hand-sides of an `inline val`s
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
new ElimRepeated, // Rewrite vararg parameters and arguments
new RefChecks) :: // Various checks mostly related to abstract members and overriding
new RefChecks, // Various checks mostly related to abstract members and overriding
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
List(new semanticdb.ExtractSemanticDB.AppendDiagnostics) :: // Attach warnings to extracted SemanticDB and write to .semanticdb file
List(new init.Checker) :: // Check initialization of objects
List(new ProtectedAccessors, // Add accessors for protected members
Expand Down
31 changes: 19 additions & 12 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ object desugar {
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** An attachment key to indicate that an Apply is created as a last `map`
* scall in a for-comprehension.
*/
val TrailingForMap: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -1966,14 +1971,8 @@ object desugar {
*
* 3.
*
* for (P <- G) yield P ==> G
*
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
*
* for (P <- G) yield E ==> G.map (P => E)
*
* Otherwise
*
* 4.
*
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
Expand Down Expand Up @@ -2146,14 +2145,20 @@ object desugar {
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
case _ => false

def markTrailingMap(aply: Apply, gen: GenFrom, selectName: TermName): Unit =
if betterForsEnabled
&& selectName == mapName
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& (deepEquals(gen.pat, body) || deepEquals(body, Tuple(Nil)))
then
aply.putAttachment(TrailingForMap, ())

enums match {
case Nil if betterForsEnabled => body
case (gen: GenFrom) :: Nil =>
if betterForsEnabled
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
&& deepEquals(gen.pat, body)
then gen.expr // avoid a redundant map with identity
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
markTrailingMap(aply, gen, mapName)
aply
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
val cont = makeFor(mapName, flatMapName, rest, body)
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
Expand All @@ -2164,7 +2169,9 @@ object desugar {
val selectName =
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
else mapName
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
markTrailingMap(aply, gen, selectName)
aply
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
val pats = valeqs map { case GenAlias(pat, _) => pat }
Expand Down
54 changes: 54 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dotty.tools.dotc
package transform.localopt

import dotty.tools.dotc.ast.tpd.*
import dotty.tools.dotc.core.Decorators.*
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.StdNames.*
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
import dotty.tools.dotc.ast.desugar

/** Drop unused trailing map calls in for comprehensions.
* We can drop the map call if:
* - it won't change the type of the expression, and
* - the function is an identity function or a const function to unit.
*
* The latter condition is checked in [[Desugar.scala#makeFor]]
*/
class DropForMap extends MiniPhase:
import DropForMap.*

override def phaseName: String = DropForMap.name

override def description: String = DropForMap.description

override def transformApply(tree: Apply)(using Context): Tree =
if !tree.hasAttachment(desugar.TrailingForMap) then tree
else tree match
case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))
if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change
f // drop the map call
case _ =>
tree.removeAttachment(desugar.TrailingForMap)
tree

private object Lambda:
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
tree match
case Block(List(defdef: DefDef), Closure(Nil, ref, _))
if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
Some((defdef.termParamss.flatten, defdef.rhs))
case _ => None

private object MapCall:
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
case Select(f, nme.map) => Some(f)
case Apply(fn, _) => unapply(fn)
case TypeApply(fn, _) => unapply(fn)
case _ => None

object DropForMap:
val name: String = "dropForMap"
val description: String = "Drop unused trailing map calls in for comprehensions"
2 changes: 1 addition & 1 deletion docs/_docs/reference/experimental/better-fors.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Additionally this extension changes the way `for`-comprehensions are desugared.
This change makes the desugaring more intuitive and avoids unnecessary `map` calls, when an alias is not followed by a guard.

2. **Avoiding Redundant `map` Calls**:
When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. but th eequality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables.
When the result of the `for`-comprehension is the same expression as the last generator pattern, the desugaring avoids an unnecessary `map` call. But the equality of the last pattern and the result has to be able to be checked syntactically, so it is either a variable or a tuple of variables. There is also a special case for dropping the `map`, if its body is a constant function, that returns `()` (`Unit` constant).
**Current Desugaring**:
```scala
for {
Expand Down
13 changes: 13 additions & 0 deletions tests/pos/better-fors-i21804.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.language.experimental.betterFors

case class Container[A](val value: A) {
def map[B](f: A => B): Container[B] = Container(f(value))
}

sealed trait Animal
case class Dog() extends Animal

def opOnDog(dog: Container[Dog]): Container[Animal] =
for
v <- dog
yield v
4 changes: 4 additions & 0 deletions tests/run/better-fors-map-elim.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
MySome(())
MySome(2)
MySome((2,3))
MySome((2,(3,4)))
64 changes: 64 additions & 0 deletions tests/run/better-fors-map-elim.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import scala.language.experimental.betterFors

class myOptionModule(doOnMap: => Unit) {
sealed trait MyOption[+A] {
def map[B](f: A => B): MyOption[B] = this match {
case MySome(x) => {
doOnMap
MySome(f(x))
}
case MyNone => MyNone
}
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
case MySome(x) => f(x)
case MyNone => MyNone
}
}
case class MySome[A](x: A) extends MyOption[A]
case object MyNone extends MyOption[Nothing]
object MyOption {
def apply[A](x: A): MyOption[A] = MySome(x)
}
}

object Test extends App {

val myOption = new myOptionModule(println("map called"))

import myOption.*

def portablePrintMyOption(opt: MyOption[Any]): Unit =
if opt == MySome(()) then
println("MySome(())")
else
println(opt)

val z = for {
a <- MyOption(1)
b <- MyOption(())
} yield ()

portablePrintMyOption(z)

val z2 = for {
a <- MyOption(1)
b <- MyOption(2)
} yield b

portablePrintMyOption(z2)

val z3 = for {
a <- MyOption(1)
(b, c) <- MyOption((2, 3))
} yield (b, c)

portablePrintMyOption(z3)

val z4 = for {
a <- MyOption(1)
(b, (c, d)) <- MyOption((2, (3, 4)))
} yield (b, (c, d))

portablePrintMyOption(z4)

}
Loading