Remove dependency on ViewModels in enableSavedStateHandles()

Rather than requiring the creation of a ViewModel as part of
the call to enableSavedStateHandle(), defer that creation
until createSavedStateHandle() is actually called. This
ensures that enableSavedStateHandles() is safe to call before
the owner is attached to a Context, etc.

Relnote: "You can now retrieve a previously registered
`SavedStateProvider` from a `SavedStateRegistry` via
`getSavedStateProvider()`."
Test: updated tests
BUG: 215406268

Change-Id: I7ea470c1af0385b8bc9d8ca653c84cc8d224e6cf
diff --git a/lifecycle/lifecycle-viewmodel-savedstate/build.gradle b/lifecycle/lifecycle-viewmodel-savedstate/build.gradle
index 5d0a4f8..19e6693 100644
--- a/lifecycle/lifecycle-viewmodel-savedstate/build.gradle
+++ b/lifecycle/lifecycle-viewmodel-savedstate/build.gradle
@@ -34,7 +34,7 @@
 dependencies {
     api("androidx.annotation:annotation:1.0.0")
     api("androidx.core:core-ktx:1.2.0")
-    api("androidx.savedstate:savedstate:1.0.0")
+    api(projectOrArtifact(":savedstate:savedstate"))
     api(projectOrArtifact(":lifecycle:lifecycle-livedata-core"))
     api(projectOrArtifact(":lifecycle:lifecycle-viewmodel"))
     api(libs.kotlinStdlib)
diff --git a/lifecycle/lifecycle-viewmodel-savedstate/src/androidTest/java/androidx/lifecycle/viewmodel/savedstate/SavedStateHandleSupportTest.kt b/lifecycle/lifecycle-viewmodel-savedstate/src/androidTest/java/androidx/lifecycle/viewmodel/savedstate/SavedStateHandleSupportTest.kt
index 04426af..12eff00 100644
--- a/lifecycle/lifecycle-viewmodel-savedstate/src/androidTest/java/androidx/lifecycle/viewmodel/savedstate/SavedStateHandleSupportTest.kt
+++ b/lifecycle/lifecycle-viewmodel-savedstate/src/androidTest/java/androidx/lifecycle/viewmodel/savedstate/SavedStateHandleSupportTest.kt
@@ -55,6 +55,7 @@
         component.resume()
         handle.set("a", "1")
         val interim = component.recreate(keepingViewModels = true)
+        interim.enableSavedStateHandles()
         handle.set("b", "2")
         interim.resume()
 
diff --git a/lifecycle/lifecycle-viewmodel-savedstate/src/main/java/androidx/lifecycle/SavedStateHandleSupport.kt b/lifecycle/lifecycle-viewmodel-savedstate/src/main/java/androidx/lifecycle/SavedStateHandleSupport.kt
index 9e69f6f..47151f7 100644
--- a/lifecycle/lifecycle-viewmodel-savedstate/src/main/java/androidx/lifecycle/SavedStateHandleSupport.kt
+++ b/lifecycle/lifecycle-viewmodel-savedstate/src/main/java/androidx/lifecycle/SavedStateHandleSupport.kt
@@ -22,10 +22,13 @@
 import androidx.annotation.MainThread
 import androidx.lifecycle.ViewModelProvider.NewInstanceFactory.Companion.VIEW_MODEL_KEY
 import androidx.lifecycle.viewmodel.CreationExtras
+import androidx.lifecycle.viewmodel.initializer
+import androidx.lifecycle.viewmodel.viewModelFactory
 import androidx.savedstate.SavedStateRegistry
 import androidx.savedstate.SavedStateRegistryOwner
 
 private const val VIEWMODEL_KEY = "androidx.lifecycle.internal.SavedStateHandlesVM"
+private const val SAVED_STATE_KEY = "androidx.lifecycle.internal.SavedStateHandlesProvider"
 
 /**
  * Enables the support of [SavedStateHandle] in a component.
@@ -44,15 +47,13 @@
         currentState == Lifecycle.State.INITIALIZED || currentState == Lifecycle.State.CREATED
     )
 
-    // make sure that SavedStateHandlesVM is created.
-    ViewModelProvider(this, object : ViewModelProvider.Factory {
-        override fun <T : ViewModel> create(modelClass: Class<T>): T {
-            @Suppress("UNCHECKED_CAST")
-            return SavedStateHandlesVM() as T
-        }
-    })[VIEWMODEL_KEY, SavedStateHandlesVM::class.java]
-
-    savedStateRegistry.runOnNextRecreation(SavedStateHandleAttacher::class.java)
+    // Add the SavedStateProvider used to save SavedStateHandles
+    // if we haven't already registered the provider
+    if (savedStateRegistry.getSavedStateProvider(SAVED_STATE_KEY) == null) {
+        val provider = SavedStateHandlesProvider(savedStateRegistry, this)
+        savedStateRegistry.registerSavedStateProvider(SAVED_STATE_KEY, provider)
+        savedStateRegistry.runOnNextRecreation(SavedStateHandleAttacher::class.java)
+    }
 }
 
 private fun createSavedStateHandle(
@@ -61,22 +62,20 @@
     key: String,
     defaultArgs: Bundle?
 ): SavedStateHandle {
-    val vm = viewModelStoreOwner.savedStateHandlesVM
-    val savedStateRegistry = savedStateRegistryOwner.savedStateRegistry
-    val handle = SavedStateHandle.createHandle(
-        savedStateRegistry.consumeRestoredStateForKey(key), defaultArgs
-    )
-    val controller = SavedStateHandleController(key, handle)
-    controller.attachToLifecycle(savedStateRegistry, savedStateRegistryOwner.lifecycle)
-    vm.controllers.add(controller)
-
-    return handle
+    val provider = savedStateRegistryOwner.savedStateHandlesProvider
+    val viewModel = viewModelStoreOwner.savedStateHandlesVM
+    // If we already have a reference to a previously created SavedStateHandle
+    // for a given key stored in our ViewModel, use that. Otherwise, create
+    // a new SavedStateHandle, providing it any restored state we might have saved
+    return viewModel.handles[key] ?: SavedStateHandle.createHandle(
+        provider.consumeRestoredStateForKey(key), defaultArgs
+    ).also { viewModel.handles[key] = it }
 }
 
 /**
  * Creates `SavedStateHandle` that can be used in your ViewModels
  *
- * This function requires `this.installSavedStateHandleSupport()` call during the component
+ * This function requires [enableSavedStateHandles] call during the component
  * initialization. Latest versions of androidx components like `ComponentActivity`, `Fragment`,
  * `NavBackStackEntry` makes this call automatically.
  *
@@ -106,39 +105,94 @@
     )
 }
 
-internal object ThrowingFactory : ViewModelProvider.Factory {
-    override fun <T : ViewModel> create(modelClass: Class<T>): T {
-        throw IllegalStateException(
-            "enableSavedStateHandles() wasn't called " +
-                "prior to createSavedStateHandle() call"
-        )
-    }
-}
-
 internal val ViewModelStoreOwner.savedStateHandlesVM: SavedStateHandlesVM
-    get() =
-        ViewModelProvider(this, ThrowingFactory)[VIEWMODEL_KEY, SavedStateHandlesVM::class.java]
+    get() = ViewModelProvider(this, viewModelFactory {
+        initializer { SavedStateHandlesVM() }
+    })[VIEWMODEL_KEY, SavedStateHandlesVM::class.java]
+
+internal val SavedStateRegistryOwner.savedStateHandlesProvider: SavedStateHandlesProvider
+    get() = savedStateRegistry.getSavedStateProvider(SAVED_STATE_KEY) as? SavedStateHandlesProvider
+        ?: throw IllegalStateException("enableSavedStateHandles() wasn't called " +
+            "prior to createSavedStateHandle() call")
 
 internal class SavedStateHandlesVM : ViewModel() {
-    val controllers = mutableListOf<SavedStateHandleController>()
+    val handles = mutableMapOf<String, SavedStateHandle>()
+}
+
+/**
+ * This single SavedStateProvider is responsible for saving the state of every
+ * SavedStateHandle associated with the SavedState/ViewModel pair.
+ */
+internal class SavedStateHandlesProvider(
+    private val savedStateRegistry: SavedStateRegistry,
+    viewModelStoreOwner: ViewModelStoreOwner
+) : SavedStateRegistry.SavedStateProvider {
+    private var restored = false
+    private var restoredState: Bundle? = null
+
+    private val viewModel by lazy {
+        viewModelStoreOwner.savedStateHandlesVM
+    }
+
+    override fun saveState(): Bundle {
+        return Bundle().apply {
+            // Ensure that even if ViewModels aren't recreated after process death and recreation
+            // that we keep their state until they are recreated
+            if (restoredState != null) {
+                putAll(restoredState)
+            }
+            // But if we do have ViewModels, prefer their state over what we may
+            // have restored
+            viewModel.handles.forEach { (key, handle) ->
+                val savedState = handle.savedStateProvider().saveState()
+                if (savedState != Bundle.EMPTY) {
+                    putBundle(key, savedState)
+                }
+            }
+        }.also {
+            // After we've saved the state, allow restoring a second time
+            restored = false
+        }
+    }
+
+    /**
+     * Restore the state from the SavedStateRegistry if it hasn't already been restored.
+     */
+    fun performRestore() {
+        if (!restored) {
+            restoredState = savedStateRegistry.consumeRestoredStateForKey(SAVED_STATE_KEY)
+            restored = true
+            // Grab a reference to the ViewModel for later usage when we saveState()
+            // This ensures that even if saveState() is called after the Lifecycle is
+            // DESTROYED, we can still save the state
+            viewModel
+        }
+    }
+
+    /**
+     * Restore the state associated with a particular SavedStateHandle, identified by its [key]
+     */
+    fun consumeRestoredStateForKey(key: String): Bundle? {
+        performRestore()
+        return restoredState?.getBundle(key).also {
+            restoredState?.remove(key)
+            if (restoredState?.isEmpty == true) {
+                restoredState = null
+            }
+        }
+    }
 }
 
 // it reconnects existent SavedStateHandles to SavedStateRegistryOwner when it is recreated
 internal class SavedStateHandleAttacher : SavedStateRegistry.AutoRecreated {
     override fun onRecreated(owner: SavedStateRegistryOwner) {
-        if (owner !is ViewModelStoreOwner) {
-            throw java.lang.IllegalStateException(
-                "Internal error: SavedStateHandleAttacher should be registered only on components" +
-                    "that implement ViewModelStoreOwner"
-            )
-        }
-        val viewModelStore = (owner as ViewModelStoreOwner).viewModelStore
-        // if savedStateHandlesVM wasn't created previously, we shouldn't trigger a creation of it
-        if (!viewModelStore.keys().contains(VIEWMODEL_KEY)) return
-        owner.savedStateHandlesVM.controllers.forEach {
-            it.attachToLifecycle(owner.savedStateRegistry, owner.lifecycle)
-        }
-        owner.savedStateRegistry.runOnNextRecreation(SavedStateHandleAttacher::class.java)
+        // if SavedStateHandlesProvider wasn't added previously, there's nothing for us to do
+        val provider = owner.savedStateRegistry
+            .getSavedStateProvider(SAVED_STATE_KEY) as? SavedStateHandlesProvider ?: return
+        // onRecreated() is called after the Lifecycle reaches CREATED, so we
+        // eagerly restore the state as part of this call to ensure it consumed
+        // even if no ViewModels are actually created during this cycle of the Lifecycle
+        provider.performRestore()
     }
 }
 
diff --git a/savedstate/savedstate/api/current.txt b/savedstate/savedstate/api/current.txt
index 71beea9..2d04ca6 100644
--- a/savedstate/savedstate/api/current.txt
+++ b/savedstate/savedstate/api/current.txt
@@ -3,6 +3,7 @@
 
   public final class SavedStateRegistry {
     method @MainThread public android.os.Bundle? consumeRestoredStateForKey(String);
+    method public androidx.savedstate.SavedStateRegistry.SavedStateProvider? getSavedStateProvider(String);
     method @MainThread public boolean isRestored();
     method @MainThread public void registerSavedStateProvider(String, androidx.savedstate.SavedStateRegistry.SavedStateProvider);
     method @MainThread public void runOnNextRecreation(Class<? extends androidx.savedstate.SavedStateRegistry.AutoRecreated>);
diff --git a/savedstate/savedstate/api/public_plus_experimental_current.txt b/savedstate/savedstate/api/public_plus_experimental_current.txt
index 71beea9..2d04ca6 100644
--- a/savedstate/savedstate/api/public_plus_experimental_current.txt
+++ b/savedstate/savedstate/api/public_plus_experimental_current.txt
@@ -3,6 +3,7 @@
 
   public final class SavedStateRegistry {
     method @MainThread public android.os.Bundle? consumeRestoredStateForKey(String);
+    method public androidx.savedstate.SavedStateRegistry.SavedStateProvider? getSavedStateProvider(String);
     method @MainThread public boolean isRestored();
     method @MainThread public void registerSavedStateProvider(String, androidx.savedstate.SavedStateRegistry.SavedStateProvider);
     method @MainThread public void runOnNextRecreation(Class<? extends androidx.savedstate.SavedStateRegistry.AutoRecreated>);
diff --git a/savedstate/savedstate/api/restricted_current.txt b/savedstate/savedstate/api/restricted_current.txt
index 71beea9..2d04ca6 100644
--- a/savedstate/savedstate/api/restricted_current.txt
+++ b/savedstate/savedstate/api/restricted_current.txt
@@ -3,6 +3,7 @@
 
   public final class SavedStateRegistry {
     method @MainThread public android.os.Bundle? consumeRestoredStateForKey(String);
+    method public androidx.savedstate.SavedStateRegistry.SavedStateProvider? getSavedStateProvider(String);
     method @MainThread public boolean isRestored();
     method @MainThread public void registerSavedStateProvider(String, androidx.savedstate.SavedStateRegistry.SavedStateProvider);
     method @MainThread public void runOnNextRecreation(Class<? extends androidx.savedstate.SavedStateRegistry.AutoRecreated>);
diff --git a/savedstate/savedstate/src/main/java/androidx/savedstate/SavedStateRegistry.java b/savedstate/savedstate/src/main/java/androidx/savedstate/SavedStateRegistry.java
index adeae54..3bca4f1 100644
--- a/savedstate/savedstate/src/main/java/androidx/savedstate/SavedStateRegistry.java
+++ b/savedstate/savedstate/src/main/java/androidx/savedstate/SavedStateRegistry.java
@@ -114,6 +114,28 @@
     }
 
     /**
+     * Get a previously registered {@link SavedStateProvider}.
+     *
+     * @param key The key used to register the {@link SavedStateProvider} when it was registered
+     *            with {@link #registerSavedStateProvider(String, SavedStateProvider)}.
+     * @return The {@link SavedStateProvider} previously registered with
+     * {@link #registerSavedStateProvider(String, SavedStateProvider)} or null if no provider
+     * has been registered with the given key.
+     * @see #registerSavedStateProvider(String, SavedStateProvider)
+     */
+    @Nullable
+    public SavedStateProvider getSavedStateProvider(@NonNull String key) {
+        SavedStateProvider provider = null;
+        for (Map.Entry<String, SavedStateProvider> entry : mComponents) {
+            if (entry.getKey().equals(key)) {
+                provider = entry.getValue();
+                break;
+            }
+        }
+        return provider;
+    }
+
+    /**
      * Unregisters a component previously registered by the given {@code key}
      *
      * @param key a key with which a component was previously registered.