diff --git a/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt b/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt index e48423e359..a9549b9ac0 100644 --- a/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt +++ b/demo/registration/src/main/java/org/signal/registration/sample/debug/NetworkDebugOverlay.kt @@ -69,7 +69,7 @@ fun NetworkDebugOverlay( onClick = { showDialog = true }, dragOffset = dragOffset, onDrag = { delta -> dragOffset += delta }, - modifier = Modifier.align(Alignment.BottomEnd) + modifier = Modifier.align(Alignment.CenterEnd) ) if (showDialog) { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryScreen.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryScreen.kt index 0cf6f9198a..aa8cdf7d99 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryScreen.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryScreen.kt @@ -5,6 +5,9 @@ package org.signal.registration.screens.phonenumber +import androidx.compose.foundation.Canvas +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -14,13 +17,14 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.text.KeyboardActions import androidx.compose.foundation.text.KeyboardOptions -import androidx.compose.material3.Button -import androidx.compose.material3.Icon +import androidx.compose.foundation.verticalScroll import androidx.compose.material3.MaterialTheme -import androidx.compose.material3.OutlinedButton import androidx.compose.material3.OutlinedTextField import androidx.compose.material3.Text import androidx.compose.runtime.Composable @@ -31,16 +35,23 @@ import androidx.compose.runtime.remember import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.graphics.Path import androidx.compose.ui.platform.testTag -import androidx.compose.ui.res.painterResource +import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.TextRange import androidx.compose.ui.text.input.ImeAction import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.text.input.TextFieldValue import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import org.signal.core.ui.compose.Buttons +import org.signal.core.ui.compose.CircularProgressWrapper import org.signal.core.ui.compose.DayNightPreviews import org.signal.core.ui.compose.Dialogs import org.signal.core.ui.compose.Previews +import org.signal.registration.R import org.signal.registration.screens.phonenumber.PhoneNumberEntryState.OneTimeEvent import org.signal.registration.test.TestTags @@ -78,10 +89,6 @@ fun PhoneNumberScreen( Box(modifier = modifier.fillMaxSize().testTag(TestTags.PHONE_NUMBER_SCREEN)) { ScreenContent(state, onEvent) - - if (state.showFullScreenSpinner) { - Dialogs.IndeterminateProgressDialog() - } } } @@ -90,23 +97,170 @@ private fun ScreenContent(state: PhoneNumberEntryState, onEvent: (PhoneNumberEnt val selectedCountry = state.countryName val selectedCountryEmoji = state.countryEmoji - // Track the phone number text field value with cursor position - var phoneNumberTextFieldValue by remember { mutableStateOf(TextFieldValue(state.formattedNumber)) } + val scrollState = rememberScrollState() - // Update the text field value when state.formattedNumber changes, preserving cursor position - LaunchedEffect(state.formattedNumber) { - if (phoneNumberTextFieldValue.text != state.formattedNumber) { + Column( + modifier = Modifier + .fillMaxSize() + .verticalScroll(scrollState) + ) { + // Toolbar spacer (matching the Toolbar height in the XML) + Spacer(modifier = Modifier.height(56.dp)) + + // Title - "Phone number" + Text( + text = stringResource(R.string.RegistrationActivity_phone_number), + style = MaterialTheme.typography.headlineMedium, + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 24.dp) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Subtitle - "You will receive a verification code..." + Text( + text = stringResource(R.string.RegistrationActivity_you_will_receive_a_verification_code), + style = MaterialTheme.typography.bodyLarge, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 24.dp) + ) + + Spacer(modifier = Modifier.height(36.dp)) + + // Country Picker - styled with surfaceVariant background and outline bottom border + CountryPicker( + emoji = selectedCountryEmoji, + country = selectedCountry, + onClick = { onEvent(PhoneNumberEntryScreenEvents.CountryPicker) }, + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 24.dp) + .testTag(TestTags.PHONE_NUMBER_COUNTRY_PICKER) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + PhoneNumberInputFields( + countryCode = state.countryCode, + formattedNumber = state.formattedNumber, + onCountryCodeChanged = { onEvent(PhoneNumberEntryScreenEvents.CountryCodeChanged(it)) }, + onPhoneNumberChanged = { onEvent(PhoneNumberEntryScreenEvents.PhoneNumberChanged(it)) }, + onPhoneNumberSubmitted = { onEvent(PhoneNumberEntryScreenEvents.PhoneNumberSubmitted) }, + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 24.dp) + ) + + Spacer(modifier = Modifier.weight(1f)) + + // Bottom row with the next/spinner button aligned to end + Row( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 32.dp, vertical = 16.dp), + horizontalArrangement = Arrangement.End, + verticalAlignment = Alignment.CenterVertically + ) { + CircularProgressWrapper( + isLoading = state.showSpinner + ) { + Buttons.LargeTonal( + onClick = { onEvent(PhoneNumberEntryScreenEvents.PhoneNumberSubmitted) }, + enabled = state.countryCode.isNotEmpty() && state.nationalNumber.isNotEmpty(), + modifier = Modifier.testTag(TestTags.PHONE_NUMBER_NEXT_BUTTON) + ) { + Text(stringResource(R.string.RegistrationActivity_next)) + } + } + } + } +} + +/** + * Country picker row styled to match the XML layout: + * - surfaceVariant background with outline bottom border + * - Rounded top corners (8dp outline, 4dp inner) + * - Country emoji, country name, and dropdown triangle + */ +@Composable +private fun CountryPicker( + emoji: String, + country: String, + onClick: () -> Unit, + modifier: Modifier = Modifier +) { + Box( + modifier = modifier + .clip(RoundedCornerShape(topStart = 8.dp, topEnd = 8.dp)) + .background(MaterialTheme.colorScheme.outline) + .padding(bottom = 1.dp) + .background( + color = MaterialTheme.colorScheme.surfaceVariant, + shape = RoundedCornerShape(topStart = 4.dp, topEnd = 4.dp) + ) + .clickable(onClick = onClick) + .height(56.dp) + ) { + Row( + modifier = Modifier + .fillMaxSize() + .padding(start = 16.dp, end = 12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = emoji, + fontSize = 24.sp + ) + + Spacer(modifier = Modifier.width(16.dp)) + + Text( + text = country, + style = MaterialTheme.typography.bodyLarge, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.weight(1f) + ) + + DropdownTriangle( + tint = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.size(24.dp) + ) + } + } +} + +/** + * Phone number input fields containing the country code and phone number text fields. + * Handles cursor position preservation when the formatted number changes. + */ +@Composable +private fun PhoneNumberInputFields( + countryCode: String, + formattedNumber: String, + onCountryCodeChanged: (String) -> Unit, + onPhoneNumberChanged: (String) -> Unit, + onPhoneNumberSubmitted: () -> Unit, + modifier: Modifier = Modifier +) { + // Track the phone number text field value with cursor position + var phoneNumberTextFieldValue by remember { mutableStateOf(TextFieldValue(formattedNumber)) } + + // Update the text field value when formattedNumber changes, preserving cursor position + LaunchedEffect(formattedNumber) { + if (phoneNumberTextFieldValue.text != formattedNumber) { // Calculate cursor position: count digits before cursor in old text, // then find position with same digit count in new text val oldText = phoneNumberTextFieldValue.text val oldCursorPos = phoneNumberTextFieldValue.selection.end val digitsBeforeCursor = oldText.take(oldCursorPos).count { it.isDigit() } - val newText = state.formattedNumber var digitCount = 0 - var newCursorPos = newText.length - for (i in newText.indices) { - if (newText[i].isDigit()) { + var newCursorPos = formattedNumber.length + for (i in formattedNumber.indices) { + if (formattedNumber[i].isDigit()) { digitCount++ } if (digitCount >= digitsBeforeCursor) { @@ -116,140 +270,92 @@ private fun ScreenContent(state: PhoneNumberEntryState, onEvent: (PhoneNumberEnt } phoneNumberTextFieldValue = TextFieldValue( - text = newText, + text = formattedNumber, selection = TextRange(newCursorPos) ) } } - Column( - modifier = Modifier - .fillMaxSize() - .padding(24.dp), - horizontalAlignment = Alignment.Start + Row( + modifier = modifier, + horizontalArrangement = Arrangement.Start, + verticalAlignment = Alignment.Bottom ) { - // Title - Text( - text = "Phone number", - style = MaterialTheme.typography.headlineMedium, - modifier = Modifier.fillMaxWidth() - ) - - Spacer(modifier = Modifier.height(16.dp)) - - // Subtitle - Text( - text = "You will receive a verification code", - style = MaterialTheme.typography.bodyLarge, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.fillMaxWidth() - ) - - Spacer(modifier = Modifier.height(36.dp)) - - // Country Picker Button - OutlinedButton( - onClick = { - onEvent(PhoneNumberEntryScreenEvents.CountryPicker) - }, + // Country code field + OutlinedTextField( + value = countryCode, + onValueChange = onCountryCodeChanged, modifier = Modifier - .fillMaxWidth() - .height(56.dp) - .testTag(TestTags.PHONE_NUMBER_COUNTRY_PICKER) - ) { - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.Start, - verticalAlignment = Alignment.CenterVertically - ) { + .width(76.dp) + .testTag(TestTags.PHONE_NUMBER_COUNTRY_CODE_FIELD), + prefix = { Text( - text = selectedCountryEmoji, - style = MaterialTheme.typography.headlineSmall - ) - Spacer(modifier = Modifier.width(16.dp)) - Text( - text = selectedCountry, + text = "+", style = MaterialTheme.typography.bodyLarge, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.weight(1f) + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.6f) ) - Icon( - painter = painterResource(android.R.drawable.arrow_down_float), - contentDescription = "Select country", - tint = MaterialTheme.colorScheme.onSurfaceVariant - ) - } - } - - Spacer(modifier = Modifier.height(20.dp)) - - // Phone number input fields - Row( - modifier = Modifier.fillMaxWidth(), - horizontalArrangement = Arrangement.spacedBy(20.dp) - ) { - // Country code field - OutlinedTextField( - value = state.countryCode, - onValueChange = { onEvent(PhoneNumberEntryScreenEvents.CountryCodeChanged(it)) }, - modifier = Modifier - .width(76.dp) - .testTag(TestTags.PHONE_NUMBER_COUNTRY_CODE_FIELD), - leadingIcon = { - Text( - text = "+", - style = MaterialTheme.typography.bodyLarge, - color = MaterialTheme.colorScheme.onSurfaceVariant - ) - }, - keyboardOptions = KeyboardOptions( - keyboardType = KeyboardType.Number, - imeAction = ImeAction.Next - ), - singleLine = true + }, + keyboardOptions = KeyboardOptions( + keyboardType = KeyboardType.Number, + imeAction = ImeAction.Done + ), + singleLine = true, + textStyle = MaterialTheme.typography.bodyLarge.copy( + color = MaterialTheme.colorScheme.onSurfaceVariant ) + ) - // Phone number field - OutlinedTextField( - value = phoneNumberTextFieldValue, - onValueChange = { newValue -> - phoneNumberTextFieldValue = newValue - onEvent(PhoneNumberEntryScreenEvents.PhoneNumberChanged(newValue.text)) - }, - modifier = Modifier - .weight(1f) - .testTag(TestTags.PHONE_NUMBER_PHONE_FIELD), - placeholder = { - Text("Phone number") - }, - keyboardOptions = KeyboardOptions( - keyboardType = KeyboardType.Phone, - imeAction = ImeAction.Done - ), - keyboardActions = KeyboardActions( - onDone = { - onEvent(PhoneNumberEntryScreenEvents.PhoneNumberSubmitted) - } - ), - singleLine = true - ) - } + Spacer(modifier = Modifier.width(20.dp)) - Spacer(modifier = Modifier.weight(1f)) - - // Next button - Button( - onClick = { - onEvent(PhoneNumberEntryScreenEvents.PhoneNumberSubmitted) + // Phone number field + OutlinedTextField( + value = phoneNumberTextFieldValue, + onValueChange = { newValue -> + phoneNumberTextFieldValue = newValue + onPhoneNumberChanged(newValue.text) }, modifier = Modifier - .fillMaxWidth() - .height(56.dp) - .testTag(TestTags.PHONE_NUMBER_NEXT_BUTTON), - enabled = state.countryCode.isNotEmpty() && state.nationalNumber.isNotEmpty() - ) { - Text("Next") + .weight(1f) + .testTag(TestTags.PHONE_NUMBER_PHONE_FIELD), + label = { + Text(stringResource(R.string.RegistrationActivity_phone_number_description)) + }, + keyboardOptions = KeyboardOptions( + keyboardType = KeyboardType.Phone, + imeAction = ImeAction.Done + ), + keyboardActions = KeyboardActions( + onDone = { onPhoneNumberSubmitted() } + ), + singleLine = true, + textStyle = MaterialTheme.typography.bodyLarge.copy( + color = MaterialTheme.colorScheme.onSurface + ) + ) + } +} + +/** + * Simple dropdown triangle icon matching the symbol_dropdown_triangle_24 vector drawable. + */ +@Composable +private fun DropdownTriangle( + tint: Color, + modifier: Modifier = Modifier +) { + Canvas(modifier = modifier) { + val w = size.width + val h = size.height + val path = Path().apply { + // Triangle pointing down, centered in the 18x24 viewport + val scaleX = w / 18f + val scaleY = h / 24f + moveTo(5.2f * scaleX, 9.5f * scaleY) + lineTo(12.8f * scaleX, 9.5f * scaleY) + lineTo(9f * scaleX, 14.95f * scaleY) + close() } + drawPath(path, tint) } } @@ -269,7 +375,7 @@ private fun PhoneNumberScreenPreview() { private fun PhoneNumberScreenSpinnerPreview() { Previews.Preview { PhoneNumberScreen( - state = PhoneNumberEntryState(showFullScreenSpinner = true), + state = PhoneNumberEntryState(showSpinner = true), onEvent = {} ) } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt index aca0733b0e..3a7e30821c 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryState.kt @@ -19,7 +19,7 @@ data class PhoneNumberEntryState( val formattedNumber: String = "", val sessionE164: String? = null, val sessionMetadata: SessionMetadata? = null, - val showFullScreenSpinner: Boolean = false, + val showSpinner: Boolean = false, val oneTimeEvent: OneTimeEvent? = null, val preExistingRegistrationData: PreExistingRegistrationData? = null, val restoredSvrCredentials: List = emptyList() diff --git a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt index cb4e4e98c9..368130cac1 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModel.kt @@ -46,7 +46,7 @@ class PhoneNumberEntryViewModel( val state = _state .combine(parentState) { state, parentState -> applyParentState(state, parentState) } .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.Eagerly, PhoneNumberEntryState()) + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), PhoneNumberEntryState()) init { viewModelScope.launch { @@ -78,10 +78,10 @@ class PhoneNumberEntryViewModel( stateEmitter(applyPhoneNumberChanged(state, event.value)) } is PhoneNumberEntryScreenEvents.PhoneNumberSubmitted -> { - var localState = state.copy(showFullScreenSpinner = true) + var localState = state.copy(showSpinner = true) stateEmitter(localState) localState = applyPhoneNumberSubmitted(localState, parentEventEmitter) - stateEmitter(localState.copy(showFullScreenSpinner = false)) + stateEmitter(localState.copy(showSpinner = false)) } is PhoneNumberEntryScreenEvents.CountryPicker -> { state.also { parentEventEmitter.navigateTo(RegistrationRoute.CountryCodePicker) } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt index 1f315d7d5a..85b5c49d09 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pincreation/PinCreationViewModel.kt @@ -48,7 +48,7 @@ class PinCreationViewModel( val state: StateFlow = _state .combine(parentState) { state, parentState -> applyParentState(state, parentState) } .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.Eagerly, PinCreationState(inputLabel = "PIN must be at least 4 digits")) + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), PinCreationState(inputLabel = "PIN must be at least 4 digits")) fun onEvent(event: PinCreationScreenEvents) { viewModelScope.launch { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt index 7df112e0c8..86e60250a5 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForRegistrationLockViewModel.kt @@ -50,7 +50,7 @@ class PinEntryForRegistrationLockViewModel( val state: StateFlow = _state .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.Eagerly, PinEntryState(showNeedHelp = true)) + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), PinEntryState(showNeedHelp = true)) fun onEvent(event: PinEntryScreenEvents) { viewModelScope.launch { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt index 41dc1143f0..3bab0d2655 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSmsBypassViewModel.kt @@ -54,7 +54,7 @@ class PinEntryForSmsBypassViewModel( val state: StateFlow = _state .combine(parentState) { state, parentState -> applyParentState(state, parentState) } .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.Eagerly, PinEntryState(showNeedHelp = true)) + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), PinEntryState(showNeedHelp = true)) fun onEvent(event: PinEntryScreenEvents) { viewModelScope.launch { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt index c03d7eb796..bf5862012c 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/pinentry/PinEntryForSvrRestoreViewModel.kt @@ -47,7 +47,7 @@ class PinEntryForSvrRestoreViewModel( val state: StateFlow = _state .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.Eagerly, PinEntryState(showNeedHelp = true)) + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), PinEntryState(showNeedHelp = true)) fun onEvent(event: PinEntryScreenEvents) { viewModelScope.launch { diff --git a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreen.kt b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreen.kt index 3cce17f1e9..bb01b89d9d 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreen.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreen.kt @@ -6,6 +6,7 @@ package org.signal.registration.screens.verificationcode import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer @@ -13,10 +14,19 @@ import androidx.compose.foundation.layout.fillMaxSize import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size import androidx.compose.foundation.layout.width +import androidx.compose.foundation.layout.wrapContentWidth +import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll +import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.MaterialTheme import androidx.compose.material3.OutlinedTextField +import androidx.compose.material3.OutlinedTextFieldDefaults +import androidx.compose.material3.Scaffold +import androidx.compose.material3.SnackbarHost +import androidx.compose.material3.SnackbarHostState import androidx.compose.material3.Text import androidx.compose.material3.TextButton import androidx.compose.runtime.Composable @@ -29,17 +39,27 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.focus.FocusRequester import androidx.compose.ui.focus.focusRequester +import androidx.compose.ui.input.key.Key +import androidx.compose.ui.input.key.KeyEventType +import androidx.compose.ui.input.key.key +import androidx.compose.ui.input.key.onKeyEvent +import androidx.compose.ui.input.key.type import androidx.compose.ui.platform.testTag +import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.input.KeyboardType import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp +import kotlinx.coroutines.delay import org.signal.core.ui.compose.DayNightPreviews import org.signal.core.ui.compose.Previews +import org.signal.registration.R import org.signal.registration.test.TestTags +import kotlin.time.Duration.Companion.seconds /** * Verification code entry screen for the registration flow. - * Displays a 6-digit code input in XXX-XXX format. + * Displays a 6-digit code input in XXX-XXX format with countdown buttons + * for resend SMS and call me actions. */ @Composable fun VerificationCodeScreen( @@ -49,177 +69,372 @@ fun VerificationCodeScreen( ) { var digits by remember { mutableStateOf(List(6) { "" }) } val focusRequesters = remember { List(6) { FocusRequester() } } + val scrollState = rememberScrollState() + val snackbarHostState = remember { SnackbarHostState() } + + // Preload error strings for use in LaunchedEffect + val incorrectCodeMsg = stringResource(R.string.VerificationCodeScreen__incorrect_code) + val networkErrorMsg = stringResource(R.string.VerificationCodeScreen__network_error) + val unknownErrorMsg = stringResource(R.string.VerificationCodeScreen__an_unexpected_error_occurred) + val smsProviderErrorMsg = stringResource(R.string.VerificationCodeScreen__sms_provider_error) + val transportErrorMsg = stringResource(R.string.VerificationCodeScreen__could_not_send_code_via_selected_method) + val registrationErrorMsg = stringResource(R.string.VerificationCodeScreen__registration_error) + // Preformat the rate-limited message template + val rateLimitedEvent = state.oneTimeEvent as? VerificationCodeState.OneTimeEvent.RateLimited + val rateLimitedMsg = if (rateLimitedEvent != null) { + stringResource(R.string.VerificationCodeScreen__too_many_attempts_try_again_in_s, rateLimitedEvent.retryAfter.toString()) + } else { + "" + } + + // Countdown timer effect - emits CountdownTick every second while timers are active + LaunchedEffect(state.rateLimits) { + if (state.rateLimits.smsResendTimeRemaining > 0.seconds || state.rateLimits.callRequestTimeRemaining > 0.seconds) { + while (true) { + delay(1000) + onEvent(VerificationCodeScreenEvents.CountdownTick) + } + } + } // Auto-submit when all digits are entered LaunchedEffect(digits) { - if (digits.all { it.isNotEmpty() }) { + if (digits.all { it.isNotEmpty() } && !state.isSubmittingCode) { val code = digits.joinToString("") onEvent(VerificationCodeScreenEvents.CodeEntered(code)) } } + // Handle one-time events — handle first, then consume LaunchedEffect(state.oneTimeEvent) { + val event = state.oneTimeEvent ?: return@LaunchedEffect + + when (event) { + VerificationCodeState.OneTimeEvent.IncorrectVerificationCode -> { + digits = List(6) { "" } + focusRequesters[0].requestFocus() + snackbarHostState.showSnackbar(incorrectCodeMsg) + } + VerificationCodeState.OneTimeEvent.NetworkError -> { + snackbarHostState.showSnackbar(networkErrorMsg) + } + is VerificationCodeState.OneTimeEvent.RateLimited -> { + snackbarHostState.showSnackbar(rateLimitedMsg) + } + VerificationCodeState.OneTimeEvent.ThirdPartyError -> { + snackbarHostState.showSnackbar(smsProviderErrorMsg) + } + VerificationCodeState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport -> { + snackbarHostState.showSnackbar(transportErrorMsg) + } + VerificationCodeState.OneTimeEvent.UnknownError -> { + snackbarHostState.showSnackbar(unknownErrorMsg) + } + VerificationCodeState.OneTimeEvent.RegistrationError -> { + snackbarHostState.showSnackbar(registrationErrorMsg) + } + } + onEvent(VerificationCodeScreenEvents.ConsumeInnerOneTimeEvent) - - when (state.oneTimeEvent) { - VerificationCodeState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport -> { } - VerificationCodeState.OneTimeEvent.IncorrectVerificationCode -> { } - VerificationCodeState.OneTimeEvent.NetworkError -> { } - is VerificationCodeState.OneTimeEvent.RateLimited -> { } - VerificationCodeState.OneTimeEvent.ThirdPartyError -> { } - VerificationCodeState.OneTimeEvent.UnknownError -> { } - VerificationCodeState.OneTimeEvent.RegistrationError -> { } - null -> { } - } - } - - Column( - modifier = modifier - .fillMaxSize() - .padding(24.dp), - horizontalAlignment = Alignment.CenterHorizontally, - verticalArrangement = Arrangement.Top - ) { - Spacer(modifier = Modifier.height(48.dp)) - - Text( - text = "Enter verification code", - style = MaterialTheme.typography.headlineMedium, - textAlign = TextAlign.Center - ) - - Spacer(modifier = Modifier.height(16.dp)) - - Text( - text = "Enter the code we sent to ${state.e164}", - style = MaterialTheme.typography.bodyMedium, - textAlign = TextAlign.Center, - color = MaterialTheme.colorScheme.onSurfaceVariant - ) - - Spacer(modifier = Modifier.height(32.dp)) - - // Code input fields - XXX-XXX format - Row( - modifier = Modifier - .fillMaxWidth() - .testTag(TestTags.VERIFICATION_CODE_INPUT), - horizontalArrangement = Arrangement.Center, - verticalAlignment = Alignment.CenterVertically - ) { - // First three digits - for (i in 0..2) { - DigitField( - value = digits[i], - onValueChange = { newValue -> - if (newValue.length <= 1 && (newValue.isEmpty() || newValue.all { it.isDigit() })) { - digits = digits.toMutableList().apply { this[i] = newValue } - if (newValue.isNotEmpty() && i < 5) { - focusRequesters[i + 1].requestFocus() - } - } - }, - focusRequester = focusRequesters[i], - testTag = when (i) { - 0 -> TestTags.VERIFICATION_CODE_DIGIT_0 - 1 -> TestTags.VERIFICATION_CODE_DIGIT_1 - else -> TestTags.VERIFICATION_CODE_DIGIT_2 - } - ) - if (i < 2) { - Spacer(modifier = Modifier.width(4.dp)) - } - } - - // Separator - Text( - text = "-", - style = MaterialTheme.typography.headlineMedium, - modifier = Modifier.padding(horizontal = 8.dp) - ) - - // Last three digits - for (i in 3..5) { - if (i > 3) { - Spacer(modifier = Modifier.width(4.dp)) - } - DigitField( - value = digits[i], - onValueChange = { newValue -> - if (newValue.length <= 1 && (newValue.isEmpty() || newValue.all { it.isDigit() })) { - digits = digits.toMutableList().apply { this[i] = newValue } - if (newValue.isNotEmpty() && i < 5) { - focusRequesters[i + 1].requestFocus() - } - } - }, - focusRequester = focusRequesters[i], - testTag = when (i) { - 3 -> TestTags.VERIFICATION_CODE_DIGIT_3 - 4 -> TestTags.VERIFICATION_CODE_DIGIT_4 - else -> TestTags.VERIFICATION_CODE_DIGIT_5 - } - ) - } - } - - Spacer(modifier = Modifier.height(32.dp)) - - TextButton( - onClick = { onEvent(VerificationCodeScreenEvents.WrongNumber) }, - modifier = Modifier.testTag(TestTags.VERIFICATION_CODE_WRONG_NUMBER_BUTTON) - ) { - Text("Wrong number?") - } - - Spacer(modifier = Modifier.weight(1f)) - - TextButton( - onClick = { onEvent(VerificationCodeScreenEvents.ResendSms) }, - modifier = Modifier - .fillMaxWidth() - .testTag(TestTags.VERIFICATION_CODE_RESEND_SMS_BUTTON) - ) { - Text("Resend SMS") - } - - TextButton( - onClick = { onEvent(VerificationCodeScreenEvents.CallMe) }, - modifier = Modifier - .fillMaxWidth() - .testTag(TestTags.VERIFICATION_CODE_CALL_ME_BUTTON) - ) { - Text("Call me instead") - } } // Auto-focus first field on initial composition LaunchedEffect(Unit) { focusRequesters[0].requestFocus() } + + Scaffold( + snackbarHost = { SnackbarHost(snackbarHostState) }, + modifier = modifier + ) { innerPadding -> + Column( + modifier = Modifier + .fillMaxSize() + .padding(innerPadding) + .verticalScroll(scrollState) + .padding(horizontal = 24.dp), + horizontalAlignment = Alignment.CenterHorizontally + ) { + Spacer(modifier = Modifier.height(40.dp)) + + // Header + Text( + text = stringResource(R.string.VerificationCodeScreen__verification_code), + style = MaterialTheme.typography.headlineMedium, + modifier = Modifier + .fillMaxWidth() + .wrapContentWidth(Alignment.Start) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Subheader with phone number + Text( + text = stringResource(R.string.VerificationCodeScreen__enter_the_code_we_sent_to_s, state.e164), + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier + .fillMaxWidth() + .wrapContentWidth(Alignment.Start) + ) + + Spacer(modifier = Modifier.height(8.dp)) + + // Wrong number button - aligned to start like in XML + TextButton( + onClick = { onEvent(VerificationCodeScreenEvents.WrongNumber) }, + modifier = Modifier + .fillMaxWidth() + .wrapContentWidth(Alignment.Start) + .testTag(TestTags.VERIFICATION_CODE_WRONG_NUMBER_BUTTON) + ) { + Text( + text = stringResource(R.string.VerificationCodeScreen__wrong_number), + color = MaterialTheme.colorScheme.primary + ) + } + + Spacer(modifier = Modifier.height(32.dp)) + + // Code input with spinner overlay when submitting + Box( + modifier = Modifier.fillMaxWidth(), + contentAlignment = Alignment.Center + ) { + // Code input fields - XXX-XXX format + Row( + modifier = Modifier + .fillMaxWidth() + .testTag(TestTags.VERIFICATION_CODE_INPUT), + horizontalArrangement = Arrangement.Center, + verticalAlignment = Alignment.CenterVertically + ) { + // First three digits + for (i in 0..2) { + DigitField( + value = digits[i], + onValueChange = { newValue, isBackspace -> + handleDigitChange( + index = i, + newValue = newValue, + isBackspace = isBackspace, + digits = digits, + focusRequesters = focusRequesters, + onDigitsChanged = { digits = it } + ) + }, + focusRequester = focusRequesters[i], + testTag = when (i) { + 0 -> TestTags.VERIFICATION_CODE_DIGIT_0 + 1 -> TestTags.VERIFICATION_CODE_DIGIT_1 + else -> TestTags.VERIFICATION_CODE_DIGIT_2 + }, + enabled = !state.isSubmittingCode + ) + if (i < 2) { + Spacer(modifier = Modifier.width(4.dp)) + } + } + + // Separator + Text( + text = "-", + style = MaterialTheme.typography.headlineMedium, + modifier = Modifier.padding(horizontal = 8.dp), + color = if (state.isSubmittingCode) MaterialTheme.colorScheme.onSurface.copy(alpha = 0.38f) else MaterialTheme.colorScheme.onSurface + ) + + // Last three digits + for (i in 3..5) { + if (i > 3) { + Spacer(modifier = Modifier.width(4.dp)) + } + DigitField( + value = digits[i], + onValueChange = { newValue, isBackspace -> + handleDigitChange( + index = i, + newValue = newValue, + isBackspace = isBackspace, + digits = digits, + focusRequesters = focusRequesters, + onDigitsChanged = { digits = it } + ) + }, + focusRequester = focusRequesters[i], + testTag = when (i) { + 3 -> TestTags.VERIFICATION_CODE_DIGIT_3 + 4 -> TestTags.VERIFICATION_CODE_DIGIT_4 + else -> TestTags.VERIFICATION_CODE_DIGIT_5 + }, + enabled = !state.isSubmittingCode + ) + } + } + + // Loading spinner overlay + if (state.isSubmittingCode) { + CircularProgressIndicator( + modifier = Modifier.size(48.dp) + ) + } + } + + Spacer(modifier = Modifier.height(32.dp)) + + // Having trouble button - shown after 3 incorrect code attempts (matching old behavior) + if (state.shouldShowHavingTrouble()) { + TextButton( + onClick = { onEvent(VerificationCodeScreenEvents.HavingTrouble) }, + modifier = Modifier + .fillMaxWidth() + .wrapContentWidth(Alignment.CenterHorizontally) + .testTag(TestTags.VERIFICATION_CODE_HAVING_TROUBLE_BUTTON) + ) { + Text( + text = stringResource(R.string.VerificationCodeScreen__having_trouble), + color = MaterialTheme.colorScheme.primary + ) + } + } + + Spacer(modifier = Modifier.weight(1f)) + + // Bottom buttons - Resend SMS and Call Me side by side + Row( + modifier = Modifier + .fillMaxWidth() + .padding(bottom = 32.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.Bottom + ) { + // Resend SMS button with countdown — fits on one line if space allows, wraps if not + val canResendSms = state.canResendSms() + val disabledColor = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.38f) + TextButton( + onClick = { onEvent(VerificationCodeScreenEvents.ResendSms) }, + enabled = canResendSms, + modifier = Modifier + .weight(1f) + .testTag(TestTags.VERIFICATION_CODE_RESEND_SMS_BUTTON) + ) { + Text( + text = if (canResendSms) { + stringResource(R.string.VerificationCodeScreen__resend_code) + } else { + val totalSeconds = state.rateLimits.smsResendTimeRemaining.inWholeSeconds.toInt() + val minutes = totalSeconds / 60 + val seconds = totalSeconds % 60 + stringResource(R.string.VerificationCodeScreen__resend_code) + " " + + stringResource(R.string.VerificationCodeScreen__countdown_format, minutes, seconds) + }, + color = if (canResendSms) MaterialTheme.colorScheme.primary else disabledColor, + textAlign = TextAlign.Center + ) + } + + Spacer(modifier = Modifier.width(8.dp)) + + // Call Me button with inline countdown + val canRequestCall = state.canRequestCall() + TextButton( + onClick = { onEvent(VerificationCodeScreenEvents.CallMe) }, + enabled = canRequestCall, + modifier = Modifier + .weight(1f) + .testTag(TestTags.VERIFICATION_CODE_CALL_ME_BUTTON) + ) { + Text( + text = if (canRequestCall) { + stringResource(R.string.VerificationCodeScreen__call_me_instead) + } else { + val totalSeconds = state.rateLimits.callRequestTimeRemaining.inWholeSeconds.toInt() + val minutes = totalSeconds / 60 + val seconds = totalSeconds % 60 + stringResource(R.string.VerificationCodeScreen__call_me_available_in, minutes, seconds) + }, + color = if (canRequestCall) MaterialTheme.colorScheme.primary else disabledColor, + textAlign = TextAlign.Center + ) + } + } + } + } } /** - * Individual digit input field + * Handles digit input changes including navigation between fields and backspace handling. + */ +private fun handleDigitChange( + index: Int, + newValue: String, + isBackspace: Boolean, + digits: List, + focusRequesters: List, + onDigitsChanged: (List) -> Unit +) { + when { + // Handle backspace on empty field - move to previous field + isBackspace && newValue.isEmpty() && index > 0 -> { + val newDigits = digits.toMutableList().apply { this[index] = "" } + onDigitsChanged(newDigits) + focusRequesters[index - 1].requestFocus() + } + // Handle new digit input + newValue.length <= 1 && (newValue.isEmpty() || newValue.all { it.isDigit() }) -> { + val newDigits = digits.toMutableList().apply { this[index] = newValue } + onDigitsChanged(newDigits) + // Move to next field if digit entered and not last field + if (newValue.isNotEmpty() && index < 5) { + focusRequesters[index + 1].requestFocus() + } + } + } +} + +/** + * Individual digit input field with backspace handling. */ @Composable private fun DigitField( value: String, - onValueChange: (String) -> Unit, + onValueChange: (String, Boolean) -> Unit, focusRequester: FocusRequester, testTag: String, - modifier: Modifier = Modifier + modifier: Modifier = Modifier, + enabled: Boolean = true ) { OutlinedTextField( value = value, - onValueChange = onValueChange, + onValueChange = { newValue -> + // Determine if this is a backspace (new value is empty and old value was not) + val isBackspace = newValue.isEmpty() && value.isNotEmpty() + onValueChange(newValue, isBackspace) + }, modifier = modifier - .width(44.dp) + .width(48.dp) .focusRequester(focusRequester) - .testTag(testTag), + .testTag(testTag) + .onKeyEvent { keyEvent -> + // Handle hardware backspace key when field is empty + if (keyEvent.type == KeyEventType.KeyUp && + (keyEvent.key == Key.Backspace || keyEvent.key == Key.Delete) && + value.isEmpty() + ) { + onValueChange("", true) + true + } else { + false + } + }, textStyle = MaterialTheme.typography.titleLarge.copy(textAlign = TextAlign.Center), singleLine = true, - keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number) + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number), + enabled = enabled, + colors = OutlinedTextFieldDefaults.colors( + disabledTextColor = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.38f), + disabledBorderColor = MaterialTheme.colorScheme.onSurface.copy(alpha = 0.12f) + ) ) } @@ -228,7 +443,40 @@ private fun DigitField( private fun VerificationCodeScreenPreview() { Previews.Preview { VerificationCodeScreen( - state = VerificationCodeState(), + state = VerificationCodeState( + e164 = "+1 555-123-4567" + ), + onEvent = {} + ) + } +} + +@DayNightPreviews +@Composable +private fun VerificationCodeScreenWithCountdownPreview() { + Previews.Preview { + VerificationCodeScreen( + state = VerificationCodeState( + e164 = "+1 555-123-4567", + rateLimits = SmsAndCallRateLimits( + smsResendTimeRemaining = 45.seconds, + callRequestTimeRemaining = 64.seconds + ) + ), + onEvent = {} + ) + } +} + +@DayNightPreviews +@Composable +private fun VerificationCodeScreenSubmittingPreview() { + Previews.Preview { + VerificationCodeScreen( + state = VerificationCodeState( + e164 = "+1 555-123-4567", + isSubmittingCode = true + ), onEvent = {} ) } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenEvents.kt b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenEvents.kt index f0e4c40759..5ed7f8955c 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenEvents.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenEvents.kt @@ -12,4 +12,9 @@ sealed class VerificationCodeScreenEvents { data object CallMe : VerificationCodeScreenEvents() data object HavingTrouble : VerificationCodeScreenEvents() data object ConsumeInnerOneTimeEvent : VerificationCodeScreenEvents() + + /** + * Event to update countdown timers. Should be triggered periodically (e.g., every second). + */ + data object CountdownTick : VerificationCodeScreenEvents() } diff --git a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeState.kt b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeState.kt index eb93ce8755..20aba7eebe 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeState.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeState.kt @@ -7,10 +7,14 @@ package org.signal.registration.screens.verificationcode import org.signal.registration.NetworkController.SessionMetadata import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds data class VerificationCodeState( val sessionMetadata: SessionMetadata? = null, val e164: String = "", + val isSubmittingCode: Boolean = false, + val rateLimits: SmsAndCallRateLimits = SmsAndCallRateLimits(), + val incorrectCodeAttempts: Int = 0, val oneTimeEvent: OneTimeEvent? = null ) { sealed interface OneTimeEvent { @@ -22,4 +26,28 @@ data class VerificationCodeState( data object IncorrectVerificationCode : OneTimeEvent data object RegistrationError : OneTimeEvent } + + /** + * Returns true if the user can resend SMS (timer has expired) + */ + fun canResendSms(): Boolean = rateLimits.smsResendTimeRemaining <= 0.seconds + + /** + * Returns true if the user can request a call (timer has expired) + */ + fun canRequestCall(): Boolean = rateLimits.callRequestTimeRemaining <= 0.seconds + + /** + * Returns true if the "Having Trouble" button should be shown. + * Matches the old behavior of showing after 3 incorrect code attempts. + */ + fun shouldShowHavingTrouble(): Boolean = incorrectCodeAttempts >= 3 } + +/** + * Rate limit data for SMS resend and phone call request countdown timers. + */ +data class SmsAndCallRateLimits( + val smsResendTimeRemaining: Duration = 0.seconds, + val callRequestTimeRemaining: Duration = 0.seconds +) diff --git a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt index c0abb01363..7f4ce51747 100644 --- a/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt +++ b/feature/registration/src/main/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModel.kt @@ -25,11 +25,15 @@ import org.signal.registration.RegistrationRoute import org.signal.registration.screens.util.navigateBack import org.signal.registration.screens.util.navigateTo import org.signal.registration.screens.verificationcode.VerificationCodeState.OneTimeEvent +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds class VerificationCodeViewModel( private val repository: RegistrationRepository, private val parentState: StateFlow, - private val parentEventEmitter: (RegistrationFlowEvent) -> Unit + private val parentEventEmitter: (RegistrationFlowEvent) -> Unit, + private val clock: () -> Long = { System.currentTimeMillis() } ) : ViewModel() { companion object { @@ -39,24 +43,35 @@ class VerificationCodeViewModel( private val _localState = MutableStateFlow(VerificationCodeState()) val state = combine(_localState, parentState) { state, parentState -> applyParentState(state, parentState) } .onEach { Log.d(TAG, "[State] $it") } - .stateIn(viewModelScope, SharingStarted.Eagerly, VerificationCodeState()) + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), VerificationCodeState()) + + private var nextSmsAvailableAt: Duration = 0.seconds + private var nextCallAvailableAt: Duration = 0.seconds fun onEvent(event: VerificationCodeScreenEvents) { viewModelScope.launch { - _localState.emit(applyEvent(state.value, event)) + val stateEmitter: (VerificationCodeState) -> Unit = { newState -> + _localState.value = newState + } + applyEvent(state.value, event, stateEmitter) } } @VisibleForTesting - suspend fun applyEvent(state: VerificationCodeState, event: VerificationCodeScreenEvents): VerificationCodeState { - return when (event) { - is VerificationCodeScreenEvents.CodeEntered -> transformCodeEntered(state, event.code) + suspend fun applyEvent(state: VerificationCodeState, event: VerificationCodeScreenEvents, stateEmitter: (VerificationCodeState) -> Unit) { + val result = when (event) { + is VerificationCodeScreenEvents.CodeEntered -> { + stateEmitter(state.copy(isSubmittingCode = true)) + applyCodeEntered(state, event.code).copy(isSubmittingCode = false) + } is VerificationCodeScreenEvents.WrongNumber -> state.also { parentEventEmitter.navigateTo(RegistrationRoute.PhoneNumberEntry) } - is VerificationCodeScreenEvents.ResendSms -> transformResendCode(state, NetworkController.VerificationCodeTransport.SMS) - is VerificationCodeScreenEvents.CallMe -> transformResendCode(state, NetworkController.VerificationCodeTransport.VOICE) + is VerificationCodeScreenEvents.ResendSms -> applyResendCode(state, NetworkController.VerificationCodeTransport.SMS) + is VerificationCodeScreenEvents.CallMe -> applyResendCode(state, NetworkController.VerificationCodeTransport.VOICE) is VerificationCodeScreenEvents.HavingTrouble -> throw NotImplementedError("having trouble flow") // TODO [registration] - Having trouble flow is VerificationCodeScreenEvents.ConsumeInnerOneTimeEvent -> state.copy(oneTimeEvent = null) + is VerificationCodeScreenEvents.CountdownTick -> applyCountdownTick(state) } + stateEmitter(result) } @VisibleForTesting @@ -67,15 +82,38 @@ class VerificationCodeViewModel( return state } + val sessionChanged = state.sessionMetadata?.id != parentState.sessionMetadata.id + + val rateLimits = if (sessionChanged) { + computeRateLimits(parentState.sessionMetadata) + } else { + state.rateLimits + } + return state.copy( sessionMetadata = parentState.sessionMetadata, - e164 = parentState.sessionE164 + e164 = parentState.sessionE164, + rateLimits = rateLimits ) } - private suspend fun transformCodeEntered(inputState: VerificationCodeState, code: String): VerificationCodeState { - var state = inputState.copy() - var sessionMetadata = state.sessionMetadata ?: return state.also { parentEventEmitter(RegistrationFlowEvent.ResetState) } + /** + * Decrements countdown timers by 1 second, ensuring they don't go below 0. + */ + private fun applyCountdownTick(state: VerificationCodeState): VerificationCodeState { + return state.copy( + rateLimits = SmsAndCallRateLimits( + smsResendTimeRemaining = (state.rateLimits.smsResendTimeRemaining - 1.seconds).coerceAtLeast(0.seconds), + callRequestTimeRemaining = (state.rateLimits.callRequestTimeRemaining - 1.seconds).coerceAtLeast(0.seconds) + ) + ) + } + + private suspend fun applyCodeEntered(inputState: VerificationCodeState, code: String): VerificationCodeState { + var state = inputState + var sessionMetadata = state.sessionMetadata ?: return state.also { + parentEventEmitter(RegistrationFlowEvent.ResetState) + } // TODO should we be checking on whether we need to do more captcha stuff? @@ -89,7 +127,8 @@ class VerificationCodeViewModel( when (result.error) { is NetworkController.SubmitVerificationCodeError.InvalidSessionIdOrVerificationCode -> { Log.w(TAG, "[SubmitCode] Invalid sessionId or verification code entered. This is distinct from an *incorrect* verification code. Body: ${result.error.message}") - return state.copy(oneTimeEvent = OneTimeEvent.IncorrectVerificationCode) + val newAttempts = state.incorrectCodeAttempts + 1 + return state.copy(oneTimeEvent = OneTimeEvent.IncorrectVerificationCode, incorrectCodeAttempts = newAttempts) } is NetworkController.SubmitVerificationCodeError.SessionNotFound -> { Log.w(TAG, "[SubmitCode] Session not found: ${result.error.message}") @@ -114,6 +153,7 @@ class VerificationCodeViewModel( } } is NetworkController.RegistrationNetworkResult.NetworkError -> { + Log.w(TAG, "[SubmitCode] Network error.", result.exception) return state.copy(oneTimeEvent = OneTimeEvent.NetworkError) } is NetworkController.RegistrationNetworkResult.ApplicationError -> { @@ -126,7 +166,8 @@ class VerificationCodeViewModel( if (!sessionMetadata.verified) { Log.w(TAG, "[SubmitCode] Verification code was incorrect.") - return state.copy(oneTimeEvent = OneTimeEvent.IncorrectVerificationCode) + val newAttempts = state.incorrectCodeAttempts + 1 + return state.copy(oneTimeEvent = OneTimeEvent.IncorrectVerificationCode, incorrectCodeAttempts = newAttempts) } // Attempt to register @@ -192,68 +233,103 @@ class VerificationCodeViewModel( } } - private suspend fun transformResendCode( - inputState: VerificationCodeState, + private suspend fun applyResendCode( + state: VerificationCodeState, transport: NetworkController.VerificationCodeTransport ): VerificationCodeState { - val state = inputState.copy() if (state.sessionMetadata == null) { parentEventEmitter(RegistrationFlowEvent.ResetState) - return inputState + return state } - val sessionMetadata = state.sessionMetadata - val result = repository.requestVerificationCode( - sessionId = sessionMetadata.id, + sessionId = state.sessionMetadata.id, smsAutoRetrieveCodeSupported = false, transport = transport ) return when (result) { is NetworkController.RegistrationNetworkResult.Success -> { - state.copy(sessionMetadata = result.data) + Log.i(TAG, "[RequestCode][$transport] Successfully requested verification code.") + parentEventEmitter(RegistrationFlowEvent.SessionUpdated(result.data)) + state.copy( + sessionMetadata = result.data, + rateLimits = computeRateLimits(result.data) + ) } is NetworkController.RegistrationNetworkResult.Failure -> { when (result.error) { is NetworkController.RequestVerificationCodeError.InvalidRequest -> { + Log.w(TAG, "[RequestCode][$transport] Invalid request: ${result.error.message}") state.copy(oneTimeEvent = OneTimeEvent.UnknownError) } is NetworkController.RequestVerificationCodeError.RateLimited -> { - state.copy(oneTimeEvent = OneTimeEvent.RateLimited(result.error.retryAfter)) + Log.w(TAG, "[RequestCode][$transport] Rate limited (retryAfter: ${result.error.retryAfter}).") + parentEventEmitter(RegistrationFlowEvent.SessionUpdated(result.error.session)) + state.copy( + oneTimeEvent = OneTimeEvent.RateLimited(result.error.retryAfter), + sessionMetadata = result.error.session, + rateLimits = computeRateLimits(result.error.session) + ) } is NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport -> { - state.copy(oneTimeEvent = OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) + Log.w(TAG, "[RequestCode][$transport] Could not fulfill with requested transport.") + parentEventEmitter(RegistrationFlowEvent.SessionUpdated(result.error.session)) + state.copy( + oneTimeEvent = OneTimeEvent.CouldNotRequestCodeWithSelectedTransport, + sessionMetadata = result.error.session, + rateLimits = computeRateLimits(result.error.session) + ) } is NetworkController.RequestVerificationCodeError.InvalidSessionId -> { + Log.w(TAG, "[RequestCode][$transport] Invalid session ID: ${result.error.message}") // TODO don't start over, go back to phone number entry parentEventEmitter(RegistrationFlowEvent.ResetState) state } is NetworkController.RequestVerificationCodeError.MissingRequestInformationOrAlreadyVerified -> { - Log.w(TAG, "When requesting verification code, missing request information or already verified.") - state.copy(oneTimeEvent = OneTimeEvent.NetworkError) + Log.w(TAG, "[RequestCode][$transport] Missing request information or already verified.") + parentEventEmitter(RegistrationFlowEvent.SessionUpdated(result.error.session)) + state.copy( + oneTimeEvent = OneTimeEvent.NetworkError, + sessionMetadata = result.error.session, + rateLimits = computeRateLimits(result.error.session) + ) } is NetworkController.RequestVerificationCodeError.SessionNotFound -> { + Log.w(TAG, "[RequestCode][$transport] Session not found: ${result.error.message}") // TODO don't start over, go back to phone number entry parentEventEmitter(RegistrationFlowEvent.ResetState) state } is NetworkController.RequestVerificationCodeError.ThirdPartyServiceError -> { + Log.w(TAG, "[RequestCode][$transport] Third party service error. ${result.error.data}") state.copy(oneTimeEvent = OneTimeEvent.ThirdPartyError) } } } is NetworkController.RegistrationNetworkResult.NetworkError -> { + Log.w(TAG, "[RequestCode][$transport] Network error.", result.exception) state.copy(oneTimeEvent = OneTimeEvent.NetworkError) } is NetworkController.RegistrationNetworkResult.ApplicationError -> { - Log.w(TAG, "Unknown error when requesting verification code.", result.exception) + Log.w(TAG, "[RequestCode][$transport] Unknown application error.", result.exception) state.copy(oneTimeEvent = OneTimeEvent.UnknownError) } } } + private fun computeRateLimits(session: NetworkController.SessionMetadata): SmsAndCallRateLimits { + val now = clock().milliseconds + nextSmsAvailableAt = now + (session.nextSms?.seconds ?: nextSmsAvailableAt) + nextCallAvailableAt = now + (session.nextCall?.seconds ?: nextCallAvailableAt) + + return SmsAndCallRateLimits( + smsResendTimeRemaining = (nextSmsAvailableAt - clock().milliseconds).coerceAtLeast(0.seconds), + callRequestTimeRemaining = (nextCallAvailableAt - clock().milliseconds).coerceAtLeast(0.seconds) + ) + } + class Factory( private val repository: RegistrationRepository, private val parentState: StateFlow, diff --git a/feature/registration/src/main/java/org/signal/registration/test/TestTags.kt b/feature/registration/src/main/java/org/signal/registration/test/TestTags.kt index 82f32a1b9b..a9cdb721e4 100644 --- a/feature/registration/src/main/java/org/signal/registration/test/TestTags.kt +++ b/feature/registration/src/main/java/org/signal/registration/test/TestTags.kt @@ -41,4 +41,5 @@ object TestTags { const val VERIFICATION_CODE_WRONG_NUMBER_BUTTON = "verification_code_wrong_number_button" const val VERIFICATION_CODE_RESEND_SMS_BUTTON = "verification_code_resend_sms_button" const val VERIFICATION_CODE_CALL_ME_BUTTON = "verification_code_call_me_button" + const val VERIFICATION_CODE_HAVING_TROUBLE_BUTTON = "verification_code_having_trouble_button" } diff --git a/feature/registration/src/main/res/values/strings.xml b/feature/registration/src/main/res/values/strings.xml index b6f2452bf7..982a83df9c 100644 --- a/feature/registration/src/main/res/values/strings.xml +++ b/feature/registration/src/main/res/values/strings.xml @@ -33,6 +33,16 @@ Send photos, videos and files from your device. + + + Phone number + + You will receive a verification code. Carrier rates may apply. + + Select a country + Phone number + Next + Your country @@ -42,4 +52,36 @@ Search by name or number Unknown country + + + + Verification code + + Enter the code we sent to %s + + Wrong number? + + Resend Code + + Call me instead + + (%1$02d:%2$02d) + + Call me (%1$02d:%2$02d) + + Incorrect code + + Unable to connect. Please check your network and try again. + + Too many attempts. Try again in %s. + + An unexpected error occurred. Please try again. + + There was a problem sending your verification code. Please try again. + + Could not send code via the selected method. Please try another option. + + Registration failed. Please try again. + + Having trouble? diff --git a/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt index bce05255d9..c9cc1575fe 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/phonenumber/PhoneNumberEntryViewModelTest.kt @@ -288,8 +288,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().sessionMetadata).isNotNull() assertThat(emittedEvents).hasSize(1) @@ -314,8 +314,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()) @@ -339,8 +339,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isNotNull() .isInstanceOf() @@ -363,8 +363,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.UnknownError) } @@ -382,8 +382,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.NetworkError) } @@ -401,8 +401,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.UnknownError) } @@ -422,8 +422,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Should not create a new session, just request verification code assertThat(emittedEvents).hasSize(1) @@ -452,8 +452,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isNotNull().isInstanceOf() } @@ -477,8 +477,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) @@ -503,8 +503,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) } @@ -530,8 +530,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() assertThat(emittedStates.last().oneTimeEvent).isEqualTo(PhoneNumberEntryState.OneTimeEvent.ThirdPartyError) } @@ -559,8 +559,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation to verification code entry assertThat(emittedEvents).hasSize(1) @@ -591,8 +591,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite no push challenge token assertThat(emittedEvents).hasSize(1) @@ -627,8 +627,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite push challenge submission failure assertThat(emittedEvents).hasSize(1) @@ -658,8 +658,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite network error assertThat(emittedEvents).hasSize(1) @@ -689,8 +689,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation continues despite application error assertThat(emittedEvents).hasSize(1) @@ -719,8 +719,8 @@ class PhoneNumberEntryViewModelTest { viewModel.applyEvent(initialState, PhoneNumberEntryScreenEvents.PhoneNumberSubmitted, stateEmitter, parentEventEmitter) // Verify spinner states - assertThat(emittedStates.first().showFullScreenSpinner).isTrue() - assertThat(emittedStates.last().showFullScreenSpinner).isFalse() + assertThat(emittedStates.first().showSpinner).isTrue() + assertThat(emittedStates.last().showSpinner).isFalse() // Verify navigation to captcha assertThat(emittedEvents).hasSize(1) diff --git a/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenTest.kt index de618598bf..6ab4d7cd8e 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeScreenTest.kt @@ -49,7 +49,7 @@ class VerificationCodeScreenTest { } // Then - composeTestRule.onNodeWithText("Enter verification code").assertIsDisplayed() + composeTestRule.onNodeWithText("Verification code").assertIsDisplayed() } @Test @@ -70,7 +70,7 @@ class VerificationCodeScreenTest { composeTestRule.onNodeWithTag(TestTags.VERIFICATION_CODE_DIGIT_2).assertIsDisplayed() composeTestRule.onNodeWithTag(TestTags.VERIFICATION_CODE_DIGIT_3).assertIsDisplayed() composeTestRule.onNodeWithTag(TestTags.VERIFICATION_CODE_DIGIT_4).assertIsDisplayed() - composeTestRule.onNodeWithTag(TestTags.VERIFICATION_CODE_DIGIT_5).assertIsDisplayed() + composeTestRule.onNodeWithTag(TestTags.VERIFICATION_CODE_DIGIT_5).fetchSemanticsNode() } @Test @@ -191,7 +191,7 @@ class VerificationCodeScreenTest { // Then composeTestRule.onNodeWithText("Wrong number?").assertIsDisplayed() - composeTestRule.onNodeWithText("Resend SMS").assertIsDisplayed() + composeTestRule.onNodeWithText("Resend Code").assertIsDisplayed() composeTestRule.onNodeWithText("Call me instead").assertIsDisplayed() } } diff --git a/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModelTest.kt b/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModelTest.kt index a2a62ac21b..aea3de272c 100644 --- a/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModelTest.kt +++ b/feature/registration/src/test/java/org/signal/registration/screens/verificationcode/VerificationCodeViewModelTest.kt @@ -11,6 +11,7 @@ import assertk.assertions.isEqualTo import assertk.assertions.isInstanceOf import assertk.assertions.isNotNull import assertk.assertions.isNull +import assertk.assertions.isTrue import assertk.assertions.prop import io.mockk.coEvery import io.mockk.mockk @@ -34,6 +35,8 @@ class VerificationCodeViewModelTest { private lateinit var parentState: MutableStateFlow private lateinit var emittedEvents: MutableList private lateinit var parentEventEmitter: (RegistrationFlowEvent) -> Unit + private lateinit var emittedStates: MutableList + private lateinit var stateEmitter: (VerificationCodeState) -> Unit @Before fun setup() { @@ -47,6 +50,8 @@ class VerificationCodeViewModelTest { ) emittedEvents = mutableListOf() parentEventEmitter = { event -> emittedEvents.add(event) } + emittedStates = mutableListOf() + stateEmitter = { state -> emittedStates.add(state) } viewModel = VerificationCodeViewModel(mockRepository, parentState, parentEventEmitter) } @@ -133,24 +138,26 @@ class VerificationCodeViewModelTest { oneTimeEvent = VerificationCodeState.OneTimeEvent.NetworkError ) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.ConsumeInnerOneTimeEvent + VerificationCodeScreenEvents.ConsumeInnerOneTimeEvent, + stateEmitter ) - assertThat(result.oneTimeEvent).isNull() + assertThat(emittedStates.last().oneTimeEvent).isNull() } @Test fun `ConsumeInnerOneTimeEvent with null event returns state with null event`() = runTest { val initialState = VerificationCodeState(oneTimeEvent = null) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.ConsumeInnerOneTimeEvent + VerificationCodeScreenEvents.ConsumeInnerOneTimeEvent, + stateEmitter ) - assertThat(result.oneTimeEvent).isNull() + assertThat(emittedStates.last().oneTimeEvent).isNull() } // ==================== applyEvent: WrongNumber Tests ==================== @@ -159,7 +166,7 @@ class VerificationCodeViewModelTest { fun `WrongNumber navigates to PhoneNumberEntry`() = runTest { val initialState = VerificationCodeState() - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.WrongNumber) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.WrongNumber, stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()) @@ -170,16 +177,42 @@ class VerificationCodeViewModelTest { // ==================== applyEvent: CodeEntered Tests ==================== + @Test + fun `CodeEntered emits isSubmittingCode true then false`() = runTest { + val sessionMetadata = createSessionMetadata() + val initialState = VerificationCodeState( + sessionMetadata = sessionMetadata, + e164 = "+15551234567" + ) + + coEvery { mockRepository.submitVerificationCode(any(), any()) } returns + NetworkController.RegistrationNetworkResult.Failure( + NetworkController.SubmitVerificationCodeError.InvalidSessionIdOrVerificationCode("Wrong code") + ) + + viewModel.applyEvent( + initialState, + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter + ) + + // First emitted state should have isSubmittingCode = true + assertThat(emittedStates.first().isSubmittingCode).isTrue() + // Final emitted state should have isSubmittingCode = false + assertThat(emittedStates.last().isSubmittingCode).isEqualTo(false) + } + @Test fun `CodeEntered emits ResetState when sessionMetadata is null`() = runTest { val initialState = VerificationCodeState(sessionMetadata = null) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result).isEqualTo(initialState) + assertThat(emittedStates.last()).isEqualTo(initialState) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()) .isInstanceOf() @@ -201,7 +234,7 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.registerAccountWithSession(any(), any(), any()) } returns NetworkController.RegistrationNetworkResult.Success(registerResponse to keyMaterial) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456")) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456"), stateEmitter) assertThat(emittedEvents).hasSize(2) assertThat(emittedEvents[0]).isInstanceOf() @@ -224,12 +257,13 @@ class VerificationCodeViewModelTest { NetworkController.SubmitVerificationCodeError.InvalidSessionIdOrVerificationCode("Wrong code") ) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.IncorrectVerificationCode) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.IncorrectVerificationCode) } @Test @@ -245,7 +279,7 @@ class VerificationCodeViewModelTest { NetworkController.SubmitVerificationCodeError.SessionNotFound("Session expired") ) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456")) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456"), stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) @@ -269,7 +303,7 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.registerAccountWithSession(any(), any(), any()) } returns NetworkController.RegistrationNetworkResult.Success(registerResponse to keyMaterial) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456")) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456"), stateEmitter) assertThat(emittedEvents).hasSize(2) assertThat(emittedEvents[0]).isInstanceOf() @@ -292,7 +326,7 @@ class VerificationCodeViewModelTest { NetworkController.SubmitVerificationCodeError.SessionAlreadyVerifiedOrNoCodeRequested(unverifiedSession) ) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456")) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456"), stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.NavigateBack) @@ -311,12 +345,13 @@ class VerificationCodeViewModelTest { NetworkController.SubmitVerificationCodeError.RateLimited(60.seconds, sessionMetadata) ) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isNotNull() + assertThat(emittedStates.last().oneTimeEvent).isNotNull() .isInstanceOf() .prop(VerificationCodeState.OneTimeEvent.RateLimited::retryAfter) .isEqualTo(60.seconds) @@ -333,12 +368,13 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.submitVerificationCode(any(), any()) } returns NetworkController.RegistrationNetworkResult.NetworkError(java.io.IOException("Network error")) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) } @Test @@ -352,12 +388,13 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.submitVerificationCode(any(), any()) } returns NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) } // ==================== applyEvent: CodeEntered - Registration Errors ==================== @@ -378,7 +415,7 @@ class VerificationCodeViewModelTest { NetworkController.RegisterAccountError.DeviceTransferPossible ) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456")) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CodeEntered("123456"), stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) @@ -400,12 +437,13 @@ class VerificationCodeViewModelTest { NetworkController.RegisterAccountError.RateLimited(30.seconds) ) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isNotNull() + assertThat(emittedStates.last().oneTimeEvent).isNotNull() .isInstanceOf() .prop(VerificationCodeState.OneTimeEvent.RateLimited::retryAfter) .isEqualTo(30.seconds) @@ -427,12 +465,13 @@ class VerificationCodeViewModelTest { NetworkController.RegisterAccountError.InvalidRequest("Bad request") ) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.RegistrationError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.RegistrationError) } @Ignore @@ -451,12 +490,13 @@ class VerificationCodeViewModelTest { NetworkController.RegisterAccountError.RegistrationRecoveryPasswordIncorrect("Wrong password") ) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.RegistrationError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.RegistrationError) } @Ignore @@ -473,12 +513,13 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.registerAccountWithSession(any(), any(), any()) } returns NetworkController.RegistrationNetworkResult.NetworkError(java.io.IOException("Network error")) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) } @Ignore @@ -495,12 +536,13 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.registerAccountWithSession(any(), any(), any()) } returns NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) - val result = viewModel.applyEvent( + viewModel.applyEvent( initialState, - VerificationCodeScreenEvents.CodeEntered("123456") + VerificationCodeScreenEvents.CodeEntered("123456"), + stateEmitter ) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) } // ==================== applyEvent: ResendSms Tests ==================== @@ -509,11 +551,11 @@ class VerificationCodeViewModelTest { fun `ResendSms with null sessionMetadata emits ResetState`() = runTest { val initialState = VerificationCodeState(sessionMetadata = null) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) - assertThat(result).isEqualTo(initialState) + assertThat(emittedStates.last()).isEqualTo(initialState) } @Test @@ -525,9 +567,9 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.requestVerificationCode(any(), any(), eq(NetworkController.VerificationCodeTransport.SMS)) } returns NetworkController.RegistrationNetworkResult.Success(updatedSession) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.sessionMetadata).isEqualTo(updatedSession) + assertThat(emittedStates.last().sessionMetadata).isEqualTo(updatedSession) } @Test @@ -540,9 +582,9 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.RateLimited(45.seconds, sessionMetadata) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isNotNull() + assertThat(emittedStates.last().oneTimeEvent).isNotNull() .isInstanceOf() .prop(VerificationCodeState.OneTimeEvent.RateLimited::retryAfter) .isEqualTo(45.seconds) @@ -558,9 +600,9 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.InvalidRequest("Bad request") ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) } @Test @@ -573,9 +615,9 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport(sessionMetadata) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) } @Test @@ -588,7 +630,7 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.InvalidSessionId("Invalid session") ) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) @@ -604,7 +646,7 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.SessionNotFound("Session not found") ) - viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) @@ -620,9 +662,9 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.MissingRequestInformationOrAlreadyVerified(sessionMetadata) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) } @Test @@ -637,9 +679,9 @@ class VerificationCodeViewModelTest { ) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.ThirdPartyError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.ThirdPartyError) } @Test @@ -650,9 +692,9 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.requestVerificationCode(any(), any(), eq(NetworkController.VerificationCodeTransport.SMS)) } returns NetworkController.RegistrationNetworkResult.NetworkError(java.io.IOException("Network error")) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.NetworkError) } @Test @@ -663,9 +705,9 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.requestVerificationCode(any(), any(), eq(NetworkController.VerificationCodeTransport.SMS)) } returns NetworkController.RegistrationNetworkResult.ApplicationError(RuntimeException("Unexpected")) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.ResendSms, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.UnknownError) } // ==================== applyEvent: CallMe Tests ==================== @@ -674,11 +716,11 @@ class VerificationCodeViewModelTest { fun `CallMe with null sessionMetadata emits ResetState`() = runTest { val initialState = VerificationCodeState(sessionMetadata = null) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe, stateEmitter) assertThat(emittedEvents).hasSize(1) assertThat(emittedEvents.first()).isEqualTo(RegistrationFlowEvent.ResetState) - assertThat(result).isEqualTo(initialState) + assertThat(emittedStates.last()).isEqualTo(initialState) } @Test @@ -690,9 +732,9 @@ class VerificationCodeViewModelTest { coEvery { mockRepository.requestVerificationCode(any(), any(), eq(NetworkController.VerificationCodeTransport.VOICE)) } returns NetworkController.RegistrationNetworkResult.Success(updatedSession) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe, stateEmitter) - assertThat(result.sessionMetadata).isEqualTo(updatedSession) + assertThat(emittedStates.last().sessionMetadata).isEqualTo(updatedSession) } @Test @@ -705,9 +747,9 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.RateLimited(90.seconds, sessionMetadata) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe, stateEmitter) - assertThat(result.oneTimeEvent).isNotNull() + assertThat(emittedStates.last().oneTimeEvent).isNotNull() .isInstanceOf() .prop(VerificationCodeState.OneTimeEvent.RateLimited::retryAfter) .isEqualTo(90.seconds) @@ -723,9 +765,9 @@ class VerificationCodeViewModelTest { NetworkController.RequestVerificationCodeError.CouldNotFulfillWithRequestedTransport(sessionMetadata) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.CouldNotRequestCodeWithSelectedTransport) } @Test @@ -740,9 +782,9 @@ class VerificationCodeViewModelTest { ) ) - val result = viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe) + viewModel.applyEvent(initialState, VerificationCodeScreenEvents.CallMe, stateEmitter) - assertThat(result.oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.ThirdPartyError) + assertThat(emittedStates.last().oneTimeEvent).isEqualTo(VerificationCodeState.OneTimeEvent.ThirdPartyError) } // ==================== Helper Functions ====================