Correct references to composer in key calls

The transform to redirect the references to the injected
composer parameter incorrectly traversed into non-inline
lambda causing them to capture the outer composer.

Fixes: b/174030267
Test: ComposerParamTransformTests
Change-Id: Icfeb4c4a7c28caad4bd5e5e87b9eb6399ca93a67
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractControlFlowTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractControlFlowTransformTests.kt
index 031b0f0..bb79bea 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractControlFlowTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractControlFlowTransformTests.kt
@@ -53,6 +53,6 @@
             var b = 2
             var c = 3
         """.trimIndent(),
-        dumpTree
+        dumpTree = dumpTree
     )
 }
\ No newline at end of file
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractIrTransformTest.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractIrTransformTest.kt
index 6f1108c..de2790c 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractIrTransformTest.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/AbstractIrTransformTest.kt
@@ -34,6 +34,7 @@
 import org.jetbrains.kotlin.descriptors.ModuleDescriptor
 import org.jetbrains.kotlin.descriptors.konan.DeserializedKlibModuleOrigin
 import org.jetbrains.kotlin.descriptors.konan.KlibModuleOrigin
+import org.jetbrains.kotlin.ir.IrElement
 import org.jetbrains.kotlin.ir.backend.jvm.serialization.EmptyLoggingContext
 import org.jetbrains.kotlin.ir.backend.jvm.serialization.JvmIrLinker
 import org.jetbrains.kotlin.ir.backend.jvm.serialization.JvmManglerDesc
@@ -150,7 +151,7 @@
             source,
             expectedTransformed,
             "",
-            dumpTree
+            dumpTree = dumpTree
         )
     }
 
@@ -158,6 +159,7 @@
         source: String,
         expectedTransformed: String,
         extra: String = "",
+        validator: (element: IrElement) -> Unit = { },
         dumpTree: Boolean = false
     ) {
         val files = listOf(
@@ -166,8 +168,10 @@
         )
         val irModule = generateIrModuleWithJvmResolve(files)
         val keySet = mutableListOf<Int>()
+        fun IrElement.validate(): IrElement = this.also { validator(it) }
         val actualTransformed = irModule
             .files[0]
+            .validate()
             .dumpSrc()
             .replace('$', '%')
             // replace source keys for start group calls
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ClassStabilityTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ClassStabilityTransformTests.kt
index 91c9176..702cfe0 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ClassStabilityTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ClassStabilityTransformTests.kt
@@ -1009,6 +1009,6 @@
         checked,
         expectedTransformed,
         unchecked,
-        dumpTree
+        dumpTree = dumpTree
     )
 }
\ No newline at end of file
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt
index 3b25765..07fa0354 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamTransformTests.kt
@@ -16,12 +16,18 @@
 
 package androidx.compose.compiler.plugins.kotlin
 
+import org.jetbrains.kotlin.ir.IrElement
+import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
+import org.jetbrains.kotlin.ir.declarations.IrValueParameter
+import org.jetbrains.kotlin.ir.expressions.IrGetValue
+import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
 import org.junit.Test
 
 class ComposerParamTransformTests : ComposeIrTransformTest() {
     private fun composerParam(
         source: String,
         expectedTransformed: String,
+        validator: (element: IrElement) -> Unit = { },
         dumpTree: Boolean = false
     ) = verifyComposeIrTransform(
         """
@@ -42,6 +48,7 @@
         """.trimIndent(),
         expectedTransformed,
         "",
+        validator,
         dumpTree
     )
 
@@ -332,6 +339,126 @@
     )
 
     @Test
+    fun testKeyCall() {
+        composerParam(
+            """
+                import androidx.compose.runtime.key
+
+                @Composable
+                fun Wrapper(block: @Composable () -> Unit) {
+                    block()
+                }
+
+                @Composable
+                fun Leaf(text: String) { }
+
+                @Composable
+                fun Test(value: Int) {
+                    key(value) {
+                        Wrapper {
+                            Leaf("Value ${'$'}value")
+                        }
+                    }
+                }
+            """,
+            """
+                @Composable
+                fun Wrapper(block: Function2<Composer<*>, Int, Unit>, %composer: Composer<*>?, %changed: Int) {
+                  %composer.startRestartGroup(<>, "C(Wrapper)<block(...>:Test.kt#2487m")
+                  val %dirty = %changed
+                  if (%changed and 0b1110 === 0) {
+                    %dirty = %dirty or if (%composer.changed(block)) 0b0100 else 0b0010
+                  }
+                  if (%dirty and 0b1011 xor 0b0010 !== 0 || !%composer.skipping) {
+                    block(%composer, 0b1110 and %dirty)
+                  } else {
+                    %composer.skipToGroupEnd()
+                  }
+                  %composer.endRestartGroup()?.updateScope { %composer: Composer<*>?, %force: Int ->
+                    Wrapper(block, %composer, %changed or 0b0001)
+                  }
+                }
+                @Composable
+                fun Leaf(text: String, %composer: Composer<*>?, %changed: Int) {
+                  %composer.startRestartGroup(<>, "C(Leaf):Test.kt#2487m")
+                  val %dirty = %changed
+                  if (%changed and 0b1110 === 0) {
+                    %dirty = %dirty or if (%composer.changed(text)) 0b0100 else 0b0010
+                  }
+                  if (%dirty and 0b1011 xor 0b0010 !== 0 || !%composer.skipping) {
+                  } else {
+                    %composer.skipToGroupEnd()
+                  }
+                  %composer.endRestartGroup()?.updateScope { %composer: Composer<*>?, %force: Int ->
+                    Leaf(text, %composer, %changed or 0b0001)
+                  }
+                }
+                @Composable
+                fun Test(value: Int, %composer: Composer<*>?, %changed: Int) {
+                  %composer.startRestartGroup(<>, "C(Test):Test.kt#2487m")
+                  val %dirty = %changed
+                  if (%changed and 0b1110 === 0) {
+                    %dirty = %dirty or if (%composer.changed(value)) 0b0100 else 0b0010
+                  }
+                  if (%dirty and 0b1011 xor 0b0010 !== 0 || !%composer.skipping) {
+                    %composer.startMovableGroup(<>, value, "<Wrappe...>")
+                    Wrapper(composableLambda(%composer, <>, true, "C<Leaf("...>:Test.kt#2487m") { %composer: Composer<*>?, %changed: Int ->
+                      if (%changed and 0b1011 xor 0b0010 !== 0 || !%composer.skipping) {
+                        Leaf("Value %value", %composer, 0)
+                      } else {
+                        %composer.skipToGroupEnd()
+                      }
+                    }, %composer, 0b0110)
+                    %composer.endMovableGroup()
+                  } else {
+                    %composer.skipToGroupEnd()
+                  }
+                  %composer.endRestartGroup()?.updateScope { %composer: Composer<*>?, %force: Int ->
+                    Test(value, %composer, %changed or 0b0001)
+                  }
+                }
+            """,
+            validator = { element ->
+                // Validate that no composers are captured by nested lambdas
+                var currentComposer: IrValueParameter? = null
+                element.accept(
+                    object : IrElementVisitorVoid {
+                        override fun visitSimpleFunction(declaration: IrSimpleFunction) {
+                            val composer = declaration.valueParameters.firstOrNull {
+                                it.name == KtxNameConventions.COMPOSER_PARAMETER
+                            }
+                            val oldComposer = currentComposer
+                            if (composer != null) currentComposer = composer
+                            super.visitSimpleFunction(declaration)
+                            currentComposer = oldComposer
+                        }
+
+                        override fun visitElement(element: IrElement) {
+                            element.acceptChildren(this, null)
+                        }
+
+                        override fun visitGetValue(expression: IrGetValue) {
+                            super.visitGetValue(expression)
+                            val value = expression.symbol.owner
+                            if (
+                                value is IrValueParameter && value.name ==
+                                KtxNameConventions.COMPOSER_PARAMETER
+                            ) {
+                                assertEquals(
+                                    "Composer unexpectedly captured",
+                                    currentComposer,
+                                    value
+                                )
+                            }
+                        }
+                    },
+                    null
+                )
+            }
+        )
+    }
+
+    @Test
     fun testComposableNestedCall() {
         composerParam(
             """
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/DefaultParamTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/DefaultParamTransformTests.kt
index 060a92ea..7aa99d4 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/DefaultParamTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/DefaultParamTransformTests.kt
@@ -38,7 +38,7 @@
 
             $unchecked
         """.trimIndent(),
-        dumpTree
+        dumpTree = dumpTree
     )
 
     @Test
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/FunctionBodySkippingTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/FunctionBodySkippingTransformTests.kt
index 12ff304..b460a1d 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/FunctionBodySkippingTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/FunctionBodySkippingTransformTests.kt
@@ -37,7 +37,7 @@
 
             $unchecked
         """.trimIndent(),
-        dumpTree
+        dumpTree = dumpTree
     )
 
     @Test
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LiveLiteralTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LiveLiteralTransformTests.kt
index 98d855c..60241c9 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LiveLiteralTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LiveLiteralTransformTests.kt
@@ -660,6 +660,6 @@
             import androidx.compose.runtime.Composable
             $unchecked
         """.trimIndent(),
-        dumpTree
+        dumpTree = dumpTree
     )
 }
\ No newline at end of file
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RememberIntrinsicTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RememberIntrinsicTransformTests.kt
index f0780f7..524d05e 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RememberIntrinsicTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/RememberIntrinsicTransformTests.kt
@@ -38,7 +38,7 @@
 
             $unchecked
         """.trimIndent(),
-        dumpTree
+        dumpTree = dumpTree
     )
 
     @Test
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/StabilityPropagationTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/StabilityPropagationTransformTests.kt
index 30d9daf..adcbad3 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/StabilityPropagationTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/StabilityPropagationTransformTests.kt
@@ -39,7 +39,7 @@
 
             $unchecked
         """.trimIndent(),
-        dumpTree
+        dumpTree = dumpTree
     )
 
     @Test
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposableFunctionBodyTransformer.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposableFunctionBodyTransformer.kt
index cef3a82..f38d76b 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposableFunctionBodyTransformer.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposableFunctionBodyTransformer.kt
@@ -2766,8 +2766,16 @@
         }
 
         // now after the inner block is extracted, the $composer parameter used in the block needs
-        // to be remapped to the outer composer instead.
+        // to be remapped to the outer composer instead for the expression and any inlined lambdas.
         block.transformChildrenVoid(object : IrElementTransformerVoid() {
+            override fun visitFunction(declaration: IrFunction): IrStatement {
+                if (declaration.isInlinedLambda()) {
+                    return super.visitFunction(declaration)
+                } else {
+                    return declaration
+                }
+            }
+
             override fun visitGetValue(expression: IrGetValue): IrExpression {
                 super.visitGetValue(expression)
 
diff --git a/compose/ui/ui-tooling/src/androidTest/java/androidx/compose/ui/tooling/BoundsTest.kt b/compose/ui/ui-tooling/src/androidTest/java/androidx/compose/ui/tooling/BoundsTest.kt
index 2e0c7e8..5317df3 100644
--- a/compose/ui/ui-tooling/src/androidTest/java/androidx/compose/ui/tooling/BoundsTest.kt
+++ b/compose/ui/ui-tooling/src/androidTest/java/androidx/compose/ui/tooling/BoundsTest.kt
@@ -36,7 +36,6 @@
 import androidx.test.filters.MediumTest
 import org.junit.Assert
 import org.junit.Before
-import org.junit.Ignore
 import org.junit.Test
 import org.junit.runner.RunWith
 import java.util.concurrent.CountDownLatch
@@ -129,7 +128,6 @@
         }
     }
 
-    @Ignore("b/174030267")
     @Test
     @LargeTest
     fun testDisposeWithComposeTables() {