diff --git a/.github/workflows/deb-builds.yaml b/.github/workflows/deb-builds.yaml index 22e6951c963..e0e3005b2cb 100644 --- a/.github/workflows/deb-builds.yaml +++ b/.github/workflows/deb-builds.yaml @@ -47,10 +47,10 @@ jobs: target_system="${{ inputs.os }}-${{ inputs.os-version }}" case "$target_system" in debian-sid) - cp -av packaging/debian-sid debian + ln -sfn packaging/debian-sid debian ;; ubuntu-*) - cp -av packaging/ubuntu-16.04 debian + ln -sfn packaging/ubuntu-16.04 debian ;; *) echo "unsupported deb packaging for $target_system" diff --git a/cmd/libsnap-confine-private/mountinfo-test.c b/cmd/libsnap-confine-private/mountinfo-test.c index 2078d4634e0..d54b699c282 100644 --- a/cmd/libsnap-confine-private/mountinfo-test.c +++ b/cmd/libsnap-confine-private/mountinfo-test.c @@ -215,6 +215,55 @@ static void test_parse_mountinfo_entry__broken_octal_escaping(void) { g_assert_null(entry->next); } +static void test_parse_mountinfo_entry__partial_escape_oob(void) { + // Regression tests: partial octal escape sequences (fewer than 3 octal + // digits after the backslash, or just a trailing backslash) must not cause + // out-of-bounds reads. Each partial escape is copied verbatim. + const char *line; + struct sc_mountinfo_entry *entry; + + // Lone trailing backslash at end of string. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source"); + g_assert_cmpstr(entry->super_opts, ==, "rw\\"); + + // Backslash followed by one octal digit at end of string. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\0"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source"); + g_assert_cmpstr(entry->super_opts, ==, "rw\\0"); + + // Backslash followed by two octal digits at end of string. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source rw\\05"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source"); + g_assert_cmpstr(entry->super_opts, ==, "rw\\05"); + + // Backslash followed by one octal digit then space (partial escape at + // end of a space-delimited field, not at end of string). + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source\\5 rw"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source\\5"); + g_assert_cmpstr(entry->super_opts, ==, "rw"); + + // Backslash followed by two octal digits then space. + line = "2074 27 0:54 / /tmp/dir rw - tmpfs source\\57 rw"; + entry = sc_parse_mountinfo_entry(line); + g_assert_nonnull(entry); + g_test_queue_destroy((GDestroyNotify)sc_free_mountinfo_entry, entry); + g_assert_cmpstr(entry->mount_source, ==, "source\\57"); + g_assert_cmpstr(entry->super_opts, ==, "rw"); +} + static void test_parse_mountinfo_entry__unescaped_whitespace(void) { // The kernel does not escape '\r' const char *line = "2074 27 0:54 / /tmp/strange\rdir rw,relatime shared:1039 - tmpfs tmpfs rw"; @@ -271,6 +320,8 @@ static void __attribute__((constructor)) init(void) { g_test_add_func("/mountinfo/parse_mountinfo_entry/octal_escaping", test_parse_mountinfo_entry__octal_escaping); g_test_add_func("/mountinfo/parse_mountinfo_entry/broken_octal_escaping", test_parse_mountinfo_entry__broken_octal_escaping); + g_test_add_func("/mountinfo/parse_mountinfo_entry/partial_escape_oob", + test_parse_mountinfo_entry__partial_escape_oob); g_test_add_func("/mountinfo/parse_mountinfo_entry/unescaped_whitespace", test_parse_mountinfo_entry__unescaped_whitespace); g_test_add_func("/mountinfo/parse_mountinfo_entry/broken_9p_superblock", diff --git a/cmd/libsnap-confine-private/mountinfo.c b/cmd/libsnap-confine-private/mountinfo.c index 3b84a585e81..24d1f991584 100644 --- a/cmd/libsnap-confine-private/mountinfo.c +++ b/cmd/libsnap-confine-private/mountinfo.c @@ -134,8 +134,9 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l bool allow_spaces_in_field) { const char *input = &line[*offset]; char *output = &entry->line_buf[*offset]; - size_t input_idx = 0; // reading index - size_t output_idx = 0; // writing index + size_t input_idx = 0; // reading index + size_t output_idx = 0; // writing index + size_t input_len = strlen(input); // length of remaining input (used for bounds checks below) // Scan characters until we run out of memory to scan or we find a // space. The kernel uses simple octal escape sequences for the @@ -169,14 +170,13 @@ static char *parse_next_string_field_ex(sc_mountinfo_entry *entry, const char *l break; } else if (c == '\\') { // Three *more* octal digits required for the escape - // sequence. For reference see mangle_path() in - // fs/seq_file.c. Note that is_octal_digit returns - // false on the string terminator character NUL and the - // short-circuiting behavior of && makes this check - // correct even if '\\' is the last character of the - // string. + // sequence. For reference see mangle_path() in + // fs/seq_file.c. We explicitly verify that at least 3 + // more bytes remain before the end of the string to + // prevent out-of-bounds reads when the input contains + // fewer than 3 bytes after the backslash. const char *s = &input[input_idx]; - if (is_octal_digit(s[1]) && is_octal_digit(s[2]) && is_octal_digit(s[3])) { + if (input_idx + 4 <= input_len && is_octal_digit(s[1]) && is_octal_digit(s[2]) && is_octal_digit(s[3])) { // Unescape the octal value encoded in s[1], // s[2] and s[3]. Because we are working with // byte values there are no issues related to diff --git a/cmd/snapd/export_test.go b/cmd/snapd/export_test.go index dc67d86fbae..2bc816fd88d 100644 --- a/cmd/snapd/export_test.go +++ b/cmd/snapd/export_test.go @@ -21,10 +21,14 @@ package main import ( "time" + + "github.com/snapcore/snapd/seclog" ) var ( - Run = run + Run = run + SetupSecurityLogger = setupSecurityLogger + DisableSecurityLogger = disableSecurityLogger ) func MockSyscheckCheckSystem(f func() error) (restore func()) { @@ -35,6 +39,22 @@ func MockSyscheckCheckSystem(f func() error) (restore func()) { } } +func MockSeclogSetup(f func(seclog.Impl, seclog.Sink, string, seclog.Level) error) (restore func()) { + old := seclogSetup + seclogSetup = f + return func() { + seclogSetup = old + } +} + +func MockSeclogDisable(f func() error) (restore func()) { + old := seclogDisable + seclogDisable = f + return func() { + seclogDisable = old + } +} + func MockCheckRunningConditionsRetryDelay(d time.Duration) (restore func()) { oldCheckRunningConditionsRetryDelay := checkRunningConditionsRetryDelay checkRunningConditionsRetryDelay = d diff --git a/cmd/snapd/main.go b/cmd/snapd/main.go index 89c3ffd9580..7214ecfc43b 100644 --- a/cmd/snapd/main.go +++ b/cmd/snapd/main.go @@ -33,6 +33,7 @@ import ( "github.com/snapcore/snapd/osutil" "github.com/snapcore/snapd/sandbox" "github.com/snapcore/snapd/secboot" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/snapdenv" "github.com/snapcore/snapd/snapdtool" "github.com/snapcore/snapd/syscheck" @@ -41,13 +42,33 @@ import ( var ( syscheckCheckSystem = syscheck.CheckSystem + seclogSetup = seclog.Setup + seclogDisable = seclog.Disable ) +const secLogAppID = "canonical.snapd.snapd" +const secLogMinLevel seclog.Level = seclog.LevelInfo + +func setupSecurityLogger() { + if err := seclogSetup(seclog.ImplSlog, seclog.SinkAudit, secLogAppID, secLogMinLevel); err != nil { + logger.Noticef("WARNING: %v", err) + } +} + +func disableSecurityLogger() { + if err := seclogDisable(); err != nil { + logger.Noticef("WARNING: cannot disable security logger: %v", err) + } +} + func init() { logger.SimpleSetup(nil) + setupSecurityLogger() } func main() { + defer disableSecurityLogger() + // When preseeding re-exec is not used if snapdenv.Preseeding() { logger.Noticef("running for preseeding") diff --git a/cmd/snapd/main_test.go b/cmd/snapd/main_test.go index bb69327bc83..041ee3fa0fb 100644 --- a/cmd/snapd/main_test.go +++ b/cmd/snapd/main_test.go @@ -35,6 +35,7 @@ import ( "github.com/snapcore/snapd/interfaces/seccomp" "github.com/snapcore/snapd/logger" "github.com/snapcore/snapd/osutil" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/testutil" ) @@ -60,6 +61,48 @@ func (s *snapdSuite) SetUpTest(c *C) { s.AddCleanup(restore) } +func (s *snapdSuite) TestSetupSecurityLoggerWarnsOnError(c *C) { + logbuf, restore := logger.MockLogger() + defer restore() + + restore = snapd.MockSeclogSetup(func(impl seclog.Impl, sink seclog.Sink, appID string, level seclog.Level) error { + return fmt.Errorf("security logger disabled: cannot open audit socket: permission denied") + }) + defer restore() + + snapd.SetupSecurityLogger() + + c.Check(logbuf.String(), testutil.Contains, "WARNING: security logger disabled: cannot open audit socket: permission denied") +} + +func (s *snapdSuite) TestDisableSecurityLoggerCallsDisable(c *C) { + _, restore := logger.MockLogger() + defer restore() + + disabled := false + restore = snapd.MockSeclogDisable(func() error { + disabled = true + return nil + }) + defer restore() + + snapd.DisableSecurityLogger() + c.Check(disabled, Equals, true) +} + +func (s *snapdSuite) TestDisableSecurityLoggerWarnsOnError(c *C) { + logbuf, restore := logger.MockLogger() + defer restore() + + restore = snapd.MockSeclogDisable(func() error { + return fmt.Errorf("audit socket busy") + }) + defer restore() + + snapd.DisableSecurityLogger() + c.Check(logbuf.String(), testutil.Contains, "WARNING: cannot disable security logger: audit socket busy") +} + func (s *snapdSuite) TestSyscheckFailGoesIntoDegradedMode(c *C) { logbuf, restore := logger.MockLogger() defer restore() diff --git a/core-initrd/24.04/debian/changelog b/core-initrd/24.04/debian/changelog index 637d6b69667..1ddb1df6430 100644 --- a/core-initrd/24.04/debian/changelog +++ b/core-initrd/24.04/debian/changelog @@ -1,3 +1,9 @@ +ubuntu-core-initramfs (69+2.75.2+g199.9a8c2f3+24.04) noble; urgency=medium + + * Update to snapd version 2.75.2+g199.9a8c2f3 + + -- Alfonso Sanchez-Beato Wed, 15 Apr 2026 16:49:42 -0400 + ubuntu-core-initramfs (69+2.75+g75.4b39daa+24.04) noble; urgency=medium * Update to snapd version 2.75+g75.4b39daa diff --git a/core-initrd/25.10/debian/changelog b/core-initrd/25.10/debian/changelog index 62279b587ec..96fe311a6a7 100644 --- a/core-initrd/25.10/debian/changelog +++ b/core-initrd/25.10/debian/changelog @@ -1,3 +1,9 @@ +ubuntu-core-initramfs (72+2.75.2+g199.9a8c2f3+25.10) questing; urgency=medium + + * Update to snapd version 2.75.2+g199.9a8c2f3 + + -- Alfonso Sanchez-Beato Wed, 15 Apr 2026 16:50:12 -0400 + ubuntu-core-initramfs (72+2.75+g75.4b39daa+25.10) questing; urgency=medium * Update to snapd version 2.75+g75.4b39daa diff --git a/core-initrd/26.04/debian/changelog b/core-initrd/26.04/debian/changelog index ed61e99486d..0b08697e2b6 100644 --- a/core-initrd/26.04/debian/changelog +++ b/core-initrd/26.04/debian/changelog @@ -1,3 +1,9 @@ +ubuntu-core-initramfs (73+2.75.2+g199.9a8c2f3+26.04) resolute; urgency=medium + + * Update to snapd version 2.75.2+g199.9a8c2f3 + + -- Alfonso Sanchez-Beato Wed, 15 Apr 2026 16:50:37 -0400 + ubuntu-core-initramfs (73+2.75+g75.4b39daa+26.04) resolute; urgency=medium * Update to snapd version 2.75+g75.4b39daa diff --git a/daemon/api_users.go b/daemon/api_users.go index 4550dbee2e7..0acd05c284b 100644 --- a/daemon/api_users.go +++ b/daemon/api_users.go @@ -31,6 +31,7 @@ import ( "github.com/snapcore/snapd/overlord/devicestate" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/release" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/store" ) @@ -68,6 +69,9 @@ var ( deviceStateCreateUser = devicestate.CreateUser deviceStateCreateKnownUsers = devicestate.CreateKnownUsers deviceStateRemoveUser = devicestate.RemoveUser + + seclogLogLoginSuccess = seclog.LogLoginSuccess + seclogLogLoginFailure = seclog.LogLoginFailure ) // userResponseData contains the data releated to user creation/login/query @@ -83,6 +87,17 @@ type userResponseData struct { var isEmailish = regexp.MustCompile(`.@.*\..`).MatchString +// loginError logs a login failure to the security audit log and returns resp +// unchanged. It is a convenience wrapper so that each error return path in +// loginUser can log with a single call. +func loginError(resp *apiError, snapdUser seclog.SnapdUser, code string) *apiError { + seclogLogLoginFailure(snapdUser, seclog.Reason{ + Code: code, + Message: resp.Message, + }) + return resp +} + func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { var loginData struct { Username string `json:"username"` @@ -116,41 +131,53 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { } } + // Build the user identity for security audit logging. At this + // point we know the email and optional username; the numeric ID + // is only available after successful authentication. + snapdUser := seclog.SnapdUser{ + SystemUserName: loginData.Username, + StoreUserEmail: loginData.Email, + } + overlord := c.d.overlord st := overlord.State() theStore := storeFrom(c.d) macaroon, discharge, err := theStore.LoginUser(loginData.Email, loginData.Password, loginData.Otp) switch err { case store.ErrAuthenticationNeeds2fa: - return &apiError{ + return loginError(&apiError{ Status: 401, Message: err.Error(), Kind: client.ErrorKindTwoFactorRequired, - } + }, snapdUser, seclog.ReasonTwoFactorRequired) case store.Err2faFailed: - return &apiError{ + return loginError(&apiError{ Status: 401, Message: err.Error(), Kind: client.ErrorKindTwoFactorFailed, - } + }, snapdUser, seclog.ReasonTwoFactorFailed) default: switch err := err.(type) { case store.InvalidAuthDataError: - return &apiError{ + return loginError(&apiError{ Status: 400, Message: err.Error(), Kind: client.ErrorKindInvalidAuthData, Value: err, - } + }, snapdUser, seclog.ReasonInvalidAuthData) case store.PasswordPolicyError: - return &apiError{ + return loginError(&apiError{ Status: 401, Message: err.Error(), Kind: client.ErrorKindPasswordPolicy, Value: err, - } + }, snapdUser, seclog.ReasonPasswordPolicy) + } + reason := seclog.ReasonInternal + if err == store.ErrInvalidCredentials { + reason = seclog.ReasonInvalidCredentials } - return Unauthorized(err.Error()) + return loginError(Unauthorized(err.Error()), snapdUser, reason) case nil: // continue } @@ -172,9 +199,15 @@ func loginUser(c *Command, r *http.Request, user *auth.UserState) Response { } st.Unlock() if err != nil { - return InternalError("cannot persist authentication details: %v", err) + return loginError(InternalError("cannot persist authentication details: %v", err), snapdUser, seclog.ReasonInternal) } + snapdUser.ID = int64(user.ID) + snapdUser.SystemUserName = user.Username + snapdUser.StoreUserEmail = user.Email + snapdUser.Expiration = user.Expiration + seclogLogLoginSuccess(snapdUser) + result := userResponseData{ ID: user.ID, Username: user.Username, diff --git a/daemon/api_users_test.go b/daemon/api_users_test.go index 28a8b729002..2e0ef43b241 100644 --- a/daemon/api_users_test.go +++ b/daemon/api_users_test.go @@ -39,6 +39,7 @@ import ( "github.com/snapcore/snapd/overlord/devicestate/devicestatetest" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/release" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/store" "github.com/snapcore/snapd/testutil" ) @@ -113,6 +114,11 @@ func (s *userSuite) TestLoginUser(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + s.AddCleanup(daemon.MockSeclogLogLoginSuccess(func(user seclog.SnapdUser) { + loggedUser = user + })) + s.loginUserStoreMacaroon = "user-macaroon" s.loginUserDischarge = "the-discharge-macaroon-serialized-data" buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) @@ -149,6 +155,10 @@ func (s *userSuite) TestLoginUser(c *check.C) { c.Check(err, check.IsNil) c.Check(snapdMacaroon.Id(), check.Equals, "1") c.Check(snapdMacaroon.Location(), check.Equals, "snapd") + + // security log was called with the right user details + c.Check(loggedUser.ID, check.Equals, int64(1)) + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") } func (s *userSuite) TestLoginUserWithUsername(c *check.C) { @@ -156,6 +166,11 @@ func (s *userSuite) TestLoginUserWithUsername(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + s.AddCleanup(daemon.MockSeclogLogLoginSuccess(func(user seclog.SnapdUser) { + loggedUser = user + })) + s.loginUserStoreMacaroon = "user-macaroon" s.loginUserDischarge = "the-discharge-macaroon-serialized-data" buf := bytes.NewBufferString(`{"username": "username", "email": "email@.com", "password": "password"}`) @@ -191,6 +206,11 @@ func (s *userSuite) TestLoginUserWithUsername(c *check.C) { c.Check(err, check.IsNil) c.Check(snapdMacaroon.Id(), check.Equals, "1") c.Check(snapdMacaroon.Location(), check.Equals, "snapd") + + // security log was called with the right user details + c.Check(loggedUser.ID, check.Equals, int64(1)) + c.Check(loggedUser.SystemUserName, check.Equals, "username") + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") } func (s *userSuite) TestLoginUserNoEmailWithExistentLocalUser(c *check.C) { @@ -405,6 +425,13 @@ func (s *userSuite) TestLoginUserDeveloperAPIError(c *check.C) { func (s *userSuite) TestLoginUserTwoFactorRequiredError(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + s.err = store.ErrAuthenticationNeeds2fa buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) req, err := http.NewRequest("POST", "/v2/login", buf) @@ -413,11 +440,21 @@ func (s *userSuite) TestLoginUserTwoFactorRequiredError(c *check.C) { rspe := s.errorReq(c, req, nil, actionIsExpected) c.Check(rspe.Status, check.Equals, 401) c.Check(rspe.Kind, check.Equals, client.ErrorKindTwoFactorRequired) + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedReason.Code, check.Equals, seclog.ReasonTwoFactorRequired) } func (s *userSuite) TestLoginUserTwoFactorFailedError(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + s.err = store.Err2faFailed buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) req, err := http.NewRequest("POST", "/v2/login", buf) @@ -426,11 +463,21 @@ func (s *userSuite) TestLoginUserTwoFactorFailedError(c *check.C) { rspe := s.errorReq(c, req, nil, actionIsExpected) c.Check(rspe.Status, check.Equals, 401) c.Check(rspe.Kind, check.Equals, client.ErrorKindTwoFactorFailed) + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedReason.Code, check.Equals, seclog.ReasonTwoFactorFailed) } func (s *userSuite) TestLoginUserInvalidCredentialsError(c *check.C) { s.expectLoginAccess() + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + s.err = store.ErrInvalidCredentials buf := bytes.NewBufferString(`{"username": "email@.com", "password": "password"}`) req, err := http.NewRequest("POST", "/v2/login", buf) @@ -439,6 +486,10 @@ func (s *userSuite) TestLoginUserInvalidCredentialsError(c *check.C) { rspe := s.errorReq(c, req, nil, actionIsExpected) c.Check(rspe.Status, check.Equals, 401) c.Check(rspe.Message, check.Equals, "invalid credentials") + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedReason.Code, check.Equals, seclog.ReasonInvalidCredentials) + c.Check(loggedReason.Message, check.Equals, "invalid credentials") } func (s *userSuite) TestLoginUserInvalidAuthDataError(c *check.C) { @@ -469,6 +520,34 @@ func (s *userSuite) TestLoginUserPasswordPolicyError(c *check.C) { c.Check(rspe.Value, check.DeepEquals, s.err) } +func (s *userSuite) TestLoginUserPersistError(c *check.C) { + s.expectLoginAccess() + + var loggedUser seclog.SnapdUser + var loggedReason seclog.Reason + s.AddCleanup(daemon.MockSeclogLogLoginFailure(func(user seclog.SnapdUser, reason seclog.Reason) { + loggedUser = user + loggedReason = reason + })) + + s.loginUserStoreMacaroon = "user-macaroon" + s.loginUserDischarge = "the-discharge-macaroon-serialized-data" + buf := bytes.NewBufferString(`{"username": "username", "email": "email@.com", "password": "password"}`) + req, err := http.NewRequest("POST", "/v2/login", buf) + c.Assert(err, check.IsNil) + + // Pass a user whose ID does not exist in the auth state, so + // auth.UpdateUser returns ErrInvalidUser. + fakeUser := &auth.UserState{ID: 99999, Username: "username", Email: "email@.com"} + rspe := s.errorReq(c, req, fakeUser, actionIsExpected) + c.Check(rspe.Status, check.Equals, 500) + c.Check(rspe.Message, check.Matches, "cannot persist authentication details: .*") + + c.Check(loggedUser.StoreUserEmail, check.Equals, "email@.com") + c.Check(loggedUser.SystemUserName, check.Equals, "username") + c.Check(loggedReason.Message, check.Matches, "cannot persist authentication details: .*") +} + func (s *userSuite) TestPostCreateUser(c *check.C) { s.testCreateUser(c, true) } diff --git a/daemon/export_api_users_test.go b/daemon/export_api_users_test.go index 6a4aff33dec..1c032364140 100644 --- a/daemon/export_api_users_test.go +++ b/daemon/export_api_users_test.go @@ -25,6 +25,7 @@ import ( "github.com/snapcore/snapd/overlord/auth" "github.com/snapcore/snapd/overlord/devicestate" "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/seclog" "github.com/snapcore/snapd/testutil" ) @@ -52,6 +53,14 @@ func MockDeviceStateRemoveUser(removeUser func(st *state.State, username string, return restore } +func MockSeclogLogLoginSuccess(f func(user seclog.SnapdUser)) (restore func()) { + return testutil.Mock(&seclogLogLoginSuccess, f) +} + +func MockSeclogLogLoginFailure(f func(user seclog.SnapdUser, reason seclog.Reason)) (restore func()) { + return testutil.Mock(&seclogLogLoginFailure, f) +} + type ( UserResponseData = userResponseData ) diff --git a/interfaces/builtin/content_test.go b/interfaces/builtin/content_test.go index 1e9c70387c6..793f9a2942e 100644 --- a/interfaces/builtin/content_test.go +++ b/interfaces/builtin/content_test.go @@ -20,6 +20,7 @@ package builtin_test import ( + "fmt" "path/filepath" "strings" @@ -322,6 +323,46 @@ apps: c.Assert(interfaces.BeforePreparePlug(s.iface, plug), ErrorMatches, "content plug must contain target path") } +func (s *ContentSuite) TestSanitizePlugTargetEdgeCases(c *C) { + const snapYamlTemplate = `name: content-snap +version: 1.0 +plugs: + content-plug: + interface: content + content: mycont + target: %s +` + for _, tc := range []struct { + target string + errMsg string + }{ + // explicit, well understood, covered by other unit tests + {target: "$SNAP/import"}, + {target: "$SNAP_DATA/import"}, + {target: "$SNAP_COMMON/import"}, + // bare $SNAP, no subpath + {target: "$SNAP"}, + // bare path, implicit $SNAP prefix + {target: "import"}, + // absolute path, implicit $SNAP prefix + {target: "/import"}, + // bare root, ends up as $SNAP + {target: "/"}, + + // trailing slash is not a clean path, inconsistent with the rest + {target: "$SNAP/", errMsg: `content interface path is not clean: .*`}, + } { + info := snaptest.MockInfo(c, fmt.Sprintf(snapYamlTemplate, tc.target), nil) + plug := info.Plugs["content-plug"] + err := interfaces.BeforePreparePlug(s.iface, plug) + if tc.errMsg == "" { + c.Assert(err, IsNil, Commentf("target: %s", tc.target)) + } else { + c.Assert(err, ErrorMatches, tc.errMsg, Commentf("target: %s", tc.target)) + } + } +} + func (s *ContentSuite) TestSanitizeSlotNilAttrMap(c *C) { const mockSnapYaml = `name: content-slot-snap version: 1.0 diff --git a/osutil/mountinfo_linux_test.go b/osutil/mountinfo_linux_test.go index 4746f721b85..4e454a88b4a 100644 --- a/osutil/mountinfo_linux_test.go +++ b/osutil/mountinfo_linux_test.go @@ -157,6 +157,45 @@ func (s *mountinfoSuite) TestParseMountInfoEntry5(c *C) { c.Assert(entry.MountDir, Equals, "/tmp/strange\rdir") } +// TestParseMountInfoEntryBrokenOctalEscaping checks that partial octal escape +// sequences (fewer than 3 octal digits after a backslash, including a trailing +// backslash) do not cause a panic and are preserved verbatim, consistent with +// the behaviour of the C mountinfo parser. +func (s *mountinfoSuite) TestParseMountInfoEntryBrokenOctalEscaping(c *C) { + // Non-octal chars after backslash and trailing backslash in last field. + entry, err := osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/strange-dir rw,relatime shared:1039 - tmpfs no\888thing rw\`) + c.Assert(err, IsNil) + c.Assert(entry.MountSource, Equals, `no\888thing`) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{`rw\`: ""}) + + // Backslash followed by one octal digit at end of string. + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source rw\0`) + c.Assert(err, IsNil) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{`rw\0`: ""}) + + // Backslash followed by two octal digits at end of string. + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source rw\05`) + c.Assert(err, IsNil) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{`rw\05`: ""}) + + // Backslash followed by one octal digit in mount source (field ended by space). + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source\5 rw`) + c.Assert(err, IsNil) + c.Assert(entry.MountSource, Equals, `source\5`) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{"rw": ""}) + + // Backslash followed by two octal digits in mount source. + entry, err = osutil.ParseMountInfoEntry( + `2074 27 0:54 / /tmp/dir rw - tmpfs source\57 rw`) + c.Assert(err, IsNil) + c.Assert(entry.MountSource, Equals, `source\57`) + c.Assert(entry.SuperOptions, DeepEquals, map[string]string{"rw": ""}) +} + // Test that empty mountinfo is parsed without errors. func (s *mountinfoSuite) TestReadMountInfo1(c *C) { entries, err := osutil.ReadMountInfo(strings.NewReader("")) diff --git a/overlord/confdbstate/confdbmgr.go b/overlord/confdbstate/confdbmgr.go index 7146beca59e..9529fbd7ec6 100644 --- a/overlord/confdbstate/confdbmgr.go +++ b/overlord/confdbstate/confdbmgr.go @@ -35,9 +35,14 @@ import ( ) const ( - // cacheKeyPrefix is the prefix to be concatenated with confdb IDs to form a - // cache key used to store pending access data. - cacheKeyPrefix = "confdb-accesses-" + // pendingCachePrefix is the prefix to be concatenated with confdb IDs to + // form a cache key used to store pending access data. + pendingCachePrefix = "pending-confdb-" + + // schedulingCachePrefix is the prefix to be concatenated with confdb IDs to + // form a cache key used to store access data that was unblocked and is being + // scheduled. + schedulingCachePrefix = "scheduling-confdb-" ) func setupConfdbHook(st *state.State, snapName, hookName string, ignoreError bool) *state.Task { @@ -105,6 +110,46 @@ func (m *ConfdbManager) doCommitTransaction(t *state.Task, _ *tomb.Tomb) (err er } schema := confdbAssert.Schema().DatabagSchema + hasSaveViewHook := false + for _, task := range t.Change().Tasks() { + if task.Kind() != "run-hook" { + continue + } + + var hooksup hookstate.HookSetup + err := task.Get("hook-setup", &hooksup) + if err != nil { + return fmt.Errorf(`internal error: cannot get "hook-setup" from run-hook task: %w`, err) + } + + if strings.HasPrefix(hooksup.Hook, "save-view-") { + hasSaveViewHook = true + break + } + } + + // we error early if a write may affect ephemeral data but no save-view hook + // is present. However, a change-view hook may have written to an ephemeral + // path after that so we have to check again + if !hasSaveViewHook { + var viewName string + err = t.Get("view", &viewName) + if err != nil { + return fmt.Errorf(`internal error: cannot get "view" from task: %w`, err) + } + + view := confdbAssert.Schema().View(viewName) + paths := tx.AlteredPaths() + mightAffectEph, err := view.WriteAffectsEphemeral(paths) + if err != nil { + return fmt.Errorf("cannot commit transaction: cannot check for ephemeral paths: %v", err) + } + + if mightAffectEph { + return fmt.Errorf("cannot commit transaction: write may affect ephemeral data but no save-view hook is present") + } + } + return tx.Commit(st, schema) } @@ -123,7 +168,6 @@ func (m *ConfdbManager) clearOngoingTransaction(t *state.Task, _ *tomb.Tomb) err return err } - // TODO: unblock next waiting confdb writer once we add the blocking logic return nil } @@ -202,9 +246,13 @@ type confdbTransactions struct { ReadTxIDs []string `json:"read-tx-ids,omitempty"` WriteTxID string `json:"write-tx-id,omitempty"` - // pending holds accesses that are waiting to be scheduled. It's read from + // Pending holds accesses that are waiting to be scheduled. It's read from // the state cache so it's only kept in-memory, never persisted into state. - pending []pendingAccess + Pending []access `json:"-"` + + // Scheduling holds accesses that have been unblocked (moved from pending) + // but have not yet finished scheduling tasks/exiting. + Scheduling []access `json:"-"` } // CanStartReadTx returns true if there isn't a write transaction running or @@ -214,7 +262,10 @@ func (txs *confdbTransactions) CanStartReadTx() bool { return false } - for _, access := range txs.pending { + accesses := append([]access{}, txs.Pending...) + accesses = append(accesses, txs.Scheduling...) + + for _, access := range accesses { if access.AccessType == writeAccess { return false } @@ -223,21 +274,32 @@ func (txs *confdbTransactions) CanStartReadTx() bool { return true } -// CanStartWriteTx returns true if there is no running or pending transaction. +// CanStartWriteTx returns true if there is no access currently running or +// waiting to run. func (txs *confdbTransactions) CanStartWriteTx() bool { - return txs.WriteTxID == "" && len(txs.ReadTxIDs) == 0 && len(txs.pending) == 0 + return txs.WriteTxID == "" && len(txs.ReadTxIDs) == 0 && + len(txs.Pending) == 0 && len(txs.Scheduling) == 0 } // addReadTransaction adds a read transaction for the specified confdb, if no -// write transactions is ongoing. The state must be locked by the caller. -func addReadTransaction(st *state.State, account, confdbName, id string) error { +// write transactions is ongoing. If a accessID is passed in, it'll be removed +// from the Scheduling list. The state must be locked by the caller. +func addReadTransaction(st *state.State, account, confdbName, id, accessID string) error { txs, updateTxStateFunc, err := getOngoingTxs(st, account, confdbName) if err != nil { return err } + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + if txs.WriteTxID != "" { - return fmt.Errorf("cannot read confdb (%s/%s): a write transaction is ongoing", account, confdbName) + // shouldn't happen save for programmer error + return fmt.Errorf("internal error: cannot read confdb (%s/%s): a write transaction is ongoing", account, confdbName) } txs.ReadTxIDs = append(txs.ReadTxIDs, id) @@ -246,21 +308,30 @@ func addReadTransaction(st *state.State, account, confdbName, id string) error { } // setWriteTransaction sets a write transaction for the specified confdb schema, -// if no other transactions (read or write) are ongoing. The state must be locked -// by the caller. -func setWriteTransaction(st *state.State, account, schemaName, id string) error { +// if no other transactions (read or write) are ongoing. If a accessID is passed +// in, it'll be removed from the Scheduling list. The state must be locked by +// the caller. +func setWriteTransaction(st *state.State, account, schemaName, id, accessID string) error { txs, updateTxStateFunc, err := getOngoingTxs(st, account, schemaName) if err != nil { return err } + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + if txs.WriteTxID != "" || len(txs.ReadTxIDs) != 0 { op := "read" if txs.WriteTxID != "" { op = "write" } - return fmt.Errorf("cannot write confdb (%s/%s): a %s transaction is ongoing", account, schemaName, op) + // shouldn't happen save for programmer error + return fmt.Errorf("internal error: cannot write confdb (%s/%s): a %s transaction is ongoing", account, schemaName, op) } txs.WriteTxID = id @@ -301,16 +372,35 @@ func getOngoingTxs(st *state.State, account, schemaName string) (ongoingTxs *con st.Set("confdb-ongoing-txs", confdbTxs) } - st.Cache(cacheKeyPrefix+ref, ongoingTxs.pending) + if len(ongoingTxs.Pending) == 0 { + st.Cache(pendingCachePrefix+ref, nil) + } else { + st.Cache(pendingCachePrefix+ref, ongoingTxs.Pending) + } + + if len(ongoingTxs.Scheduling) == 0 { + st.Cache(schedulingCachePrefix+ref, nil) + } else { + st.Cache(schedulingCachePrefix+ref, ongoingTxs.Scheduling) + } } - cached := st.Cached(cacheKeyPrefix + ref) + cached := st.Cached(pendingCachePrefix + ref) if cached != nil { - queue, ok := cached.([]pendingAccess) + queue, ok := cached.([]access) if !ok { return nil, nil, fmt.Errorf("internal error: cannot access confdb pending transaction queue") } - confdbTxs[ref].pending = queue + confdbTxs[ref].Pending = queue + } + + cached = st.Cached(schedulingCachePrefix + ref) + if cached != nil { + queue, ok := cached.([]access) + if !ok { + return nil, nil, fmt.Errorf("internal error: cannot access confdb scheduling list") + } + confdbTxs[ref].Scheduling = queue } return confdbTxs[ref], updateTxStateFunc, nil @@ -338,34 +428,50 @@ func unsetOngoingTransaction(st *state.State, account, schemaName, id string) er if len(txs.ReadTxIDs) > 0 { // there are other transactions running (can only be reads) so skip this. - // The last one will unblock the next access + // The last one will unblock the next accesses return nil } - // unblock any waiting routine - if len(txs.pending) > 0 { - logger.Debugf("remove pending access %s", txs.pending[0].ID) - close(txs.pending[0].WaitChan) - } - - return nil + return maybeUnblockAccesses(txs) } -func unblockNextAccess(st *state.State, account, schemaName string) error { - txs, updateTxStateFunc, err := getOngoingTxs(st, account, schemaName) - if err != nil { - return err +// maybeUnblockAccesses unblocks as many consecutive pending accesses as +// possible, either one write or one or more sequential reads. +// This may be a no-op, if there are: +// - no pending changes (i.e., there's nothing to unblock) +// - changes running for other transactions - pending accesses would've been +// scheduled w/o waiting if they could (see waitForAccess) so any pending +// accesses are guaranteed to be incompatible. +// - accesses that have been unblocked but are still scheduling changes. If we +// unblocked accesses here, they would race with the ones already scheduling +// +// If accesses are unblocked, they're removed from the Pending list and put into +// the Scheduling list so we can track unblocked but still unscheduled accesses. +func maybeUnblockAccesses(txs *confdbTransactions) error { + if len(txs.Pending) == 0 || txs.WriteTxID != "" || len(txs.ReadTxIDs) > 0 || len(txs.Scheduling) != 0 { + return nil } - if len(txs.pending) == 0 { - return nil + var upTo int + for i, acc := range txs.Pending { + if acc.AccessType == writeAccess { + if i == 0 { + acc.WaitChan <- struct{}{} + logger.Debugf("unblocking pending %s access %s", acc.AccessType, acc.ID) + upTo = i + } + + break + } + + acc.WaitChan <- struct{}{} + logger.Debugf("unblocking pending %s access %s", acc.AccessType, acc.ID) + upTo = i } - // unblock any waiting routine - logger.Debugf("remove pending access %s", txs.pending[0].ID) - close(txs.pending[0].WaitChan) + txs.Scheduling = append([]access{}, txs.Pending[:upTo+1]...) + txs.Pending = txs.Pending[upTo+1:] - updateTxStateFunc(txs) return nil } diff --git a/overlord/confdbstate/confdbmgr_test.go b/overlord/confdbstate/confdbmgr_test.go index c1c60861aad..a81bcba62ab 100644 --- a/overlord/confdbstate/confdbmgr_test.go +++ b/overlord/confdbstate/confdbmgr_test.go @@ -19,6 +19,7 @@ package confdbstate_test import ( + "context" "errors" "strings" "time" @@ -33,6 +34,7 @@ import ( "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/testutil" + "gopkg.in/tomb.v2" . "gopkg.in/check.v1" ) @@ -364,10 +366,14 @@ func (s *confdbTestSuite) TestSetAndUnsetOngoingTransactionHelpers(c *C) { err := s.state.Get("confdb-ongoing-txs", &ongoingTxs) c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) - err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1") + s.state.Cache("scheduling-confdb-my-acc/my-confdb", []confdbstate.Access{{ID: "foo"}}) + + err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1", "foo") c.Assert(err, IsNil) + accs := s.state.Cached("scheduling-confdb-my-acc/my-confdb") + c.Assert(accs, IsNil) - err = confdbstate.SetWriteTransaction(s.state, "other-acc", "other-confdb", "2") + err = confdbstate.SetWriteTransaction(s.state, "other-acc", "other-confdb", "2", "") c.Assert(err, IsNil) err = s.state.Get("confdb-ongoing-txs", &ongoingTxs) @@ -400,29 +406,32 @@ func (s *confdbTestSuite) TestConflictingOngoingTransactions(c *C) { s.state.Lock() defer s.state.Unlock() - err := confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1") + err := confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "1", "") c.Assert(err, IsNil) // can't set write due to ongoing write - err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2") - c.Assert(err, ErrorMatches, `cannot write confdb \(my-acc/my-confdb\): a write transaction is ongoing`) + err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2", "") + c.Assert(err, ErrorMatches, `internal error: cannot write confdb \(my-acc/my-confdb\): a write transaction is ongoing`) // can't add read due to ongoing write - err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2") - c.Assert(err, ErrorMatches, `cannot read confdb \(my-acc/my-confdb\): a write transaction is ongoing`) + err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2", "") + c.Assert(err, ErrorMatches, `internal error: cannot read confdb \(my-acc/my-confdb\): a write transaction is ongoing`) err = confdbstate.UnsetOngoingTransaction(s.state, "my-acc", "my-confdb", "1") c.Assert(err, IsNil) - err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "1") + s.state.Cache("scheduling-confdb-my-acc/my-confdb", []confdbstate.Access{{ID: "foo"}}) + err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "1", "foo") c.Assert(err, IsNil) + accs := s.state.Cached("scheduling-confdb-my-acc/my-confdb") + c.Assert(accs, IsNil) // can't set write due to ongoing read - err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2") - c.Assert(err, ErrorMatches, `cannot write confdb \(my-acc/my-confdb\): a read transaction is ongoing`) + err = confdbstate.SetWriteTransaction(s.state, "my-acc", "my-confdb", "2", "") + c.Assert(err, ErrorMatches, `internal error: cannot write confdb \(my-acc/my-confdb\): a read transaction is ongoing`) // many reads are fine - err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2") + err = confdbstate.AddReadTransaction(s.state, "my-acc", "my-confdb", "2", "") c.Assert(err, IsNil) } @@ -442,6 +451,7 @@ func (s *confdbTestSuite) TestCommitTransaction(c *C) { c.Assert(err, IsNil) setTransaction(t, tx) + t.Set("view", "setup-wifi") s.state.Unlock() err = s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) @@ -480,7 +490,7 @@ func (s *confdbTestSuite) TestClearOngoingTransaction(c *C) { chg.AddTask(t) t.Set("tx-task", commitTask.ID()) - confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID()) + confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID(), "") c.Assert(err, IsNil) var confdbTxs map[string]*confdbstate.ConfdbTransactions @@ -518,9 +528,10 @@ func (s *confdbTestSuite) TestClearTransactionOnError(c *C) { err = tx.Set(parsePath(c, "foo"), "bar") c.Assert(err, IsNil) setTransaction(commitTask, tx) + commitTask.Set("view", "setup-wifi") // add this transaction to the state - err = confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID()) + err = confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", commitTask.ID(), "") c.Assert(err, IsNil) s.state.Unlock() @@ -531,10 +542,68 @@ func (s *confdbTestSuite) TestClearTransactionOnError(c *C) { c.Assert(chg.Status(), Equals, state.ErrorStatus) c.Assert(commitTask.Status(), Equals, state.ErrorStatus) c.Assert(clearTask.Status(), Equals, state.UndoneStatus) - c.Assert(strings.Join(commitTask.Log(), "\n"), Matches, ".*ERROR cannot accept top level element: map contains unexpected key \"foo\"") + c.Assert(strings.Join(commitTask.Log(), "\n"), Matches, ".*ERROR cannot commit transaction: cannot check for ephemeral paths: cannot check if write affects ephemeral data: cannot use \"foo\" as key in map") // no ongoing confdb transaction var ongoingTxs map[string]*confdbstate.ConfdbTransactions err = s.state.Get("confdb-ongoing-txs", &ongoingTxs) c.Assert(err, testutil.ErrorIs, &state.NoStateError{}) } + +func (s *confdbTestSuite) TestCommitTransactionEphemeralCheckWithoutSaveViewHooks(c *C) { + s.state.Lock() + defer s.state.Unlock() + + // the custodian has a change-view hook but no save-view + custodians := map[string]confdbHooks{"custodian-snap": changeView} + s.setupConfdbScenario(c, custodians, nil) + + // mock a change-view hook that writes to ephemeral data + restore := hookstate.MockRunHook(func(ctx *hookstate.Context, _ *tomb.Tomb) ([]byte, error) { + t, _ := ctx.Task() + ctx.State().Lock() + defer ctx.State().Unlock() + + var hooksup *hookstate.HookSetup + err := t.Get("hook-setup", &hooksup) + if err != nil { + return nil, err + } + c.Assert(strings.HasPrefix(hooksup.Hook, "change-view-"), Equals, true) + + tx, _, saveChanges, err := confdbstate.GetStoredTransaction(t) + if err != nil { + return nil, err + } + + err = tx.Set(parsePath(c, "wifi.eph"), "ephemeral-from-hook") + if err != nil { + return nil, err + } + saveChanges() + + return nil, nil + }) + defer restore() + + view, err := confdbstate.GetView(s.state, s.devAccID, "network", "setup-wifi") + c.Assert(err, IsNil) + + chgID, err := confdbstate.WriteConfdb(context.Background(), s.state, view, map[string]any{"ssid": "my-wifi"}) + c.Assert(err, IsNil) + + chg := s.state.Change(chgID) + c.Assert(chg, NotNil) + + s.state.Unlock() + err = s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + s.state.Lock() + c.Assert(err, IsNil) + + // commit fails because change-view hook wrote ephemeral data but no save-view hooks exist + c.Assert(chg.Status(), Equals, state.ErrorStatus) + + commitTask := findTask(chg, "commit-confdb-tx") + c.Assert(commitTask, NotNil) + c.Assert(strings.Join(commitTask.Log(), "\n"), Matches, `.*ERROR cannot commit transaction: write may affect ephemeral data but no save-view hook is present.*`) +} diff --git a/overlord/confdbstate/confdbstate.go b/overlord/confdbstate/confdbstate.go index cff60e1cd3a..505e895ba0c 100644 --- a/overlord/confdbstate/confdbstate.go +++ b/overlord/confdbstate/confdbstate.go @@ -45,21 +45,26 @@ import ( var ( assertstateConfdbSchema = assertstate.ConfdbSchema assertstateFetchConfdbSchemaAssertion = assertstate.FetchConfdbSchemaAssertion -) -var ( setConfdbChangeKind = swfeats.RegisterChangeKind("set-confdb") getConfdbChangeKind = swfeats.RegisterChangeKind("get-confdb") - // testBlockingChan is closed right before blocking to wait for access. - blockingSignalChan chan struct{} + // blockingSignals holds channels that, if present, will be closed to signal + // that an operation is about to block. Its only use is to test some blocking + // behaviour. + blockingSignals map[string]chan struct{} defaultWaitTimeout = 10 * time.Minute + + ensureNow = func(st *state.State) { + st.EnsureBefore(0) + } + + transactionTimeout = 2 * time.Minute ) -// SetViaView uses the view to set the requests in the transaction's databag. -// TODO: unexport this once the next PR refactors the writing from snapctl -func SetViaView(bag confdb.Databag, view *confdb.View, requests map[string]any) error { +// setViaView uses the view to set the requests in the transaction's databag. +func setViaView(bag confdb.Databag, view *confdb.View, requests map[string]any) error { for request, value := range requests { var err error if value == nil { @@ -201,142 +206,144 @@ var writeDatabag = func(st *state.State, databag confdb.JSONDatabag, account, db return nil } -// waitForAccess blocks until the access can be processed or until the context -// was cancelled/timed out, in which case an error is returned. Caller must hold -// the state lock. -func waitForAccess(ctx context.Context, st *state.State, view *confdb.View, access accessType) (err error) { +// waitForAccess checks if ongoing transactions prevent this access from running +// and if necessary blocks until it can. The following scenarios can occur: +// - the access can immediately run (no ongoing tx or all are reads) - returns +// without waiting, with no accessID or error +// - the access must wait - returns after being unblocked, with a non-empty +// accessID matching an access in Processing (to be removed after scheduling) +// - any error occurs or the context times out or is cancelled - returns an +// error but no accessID, since relevant state in Processing/Pending is cleared +// +// Caller must hold the state lock. +func waitForAccess(ctx context.Context, st *state.State, view *confdb.View, accKind accessType) (accessID string, err error) { account, schema := view.Schema().Account, view.Schema().Name txs, updateTxs, err := getOngoingTxs(st, account, schema) if err != nil { - return fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) + return "", fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) } - if (access == readAccess && txs.CanStartReadTx()) || (access == writeAccess && txs.CanStartWriteTx()) { - return nil + if (accKind == readAccess && txs.CanStartReadTx()) || (accKind == writeAccess && txs.CanStartWriteTx()) { + return "", nil } - id := randutil.RandomString(20) + accessID = randutil.RandomString(20) - wait := make(chan struct{}) - txs.pending = append(txs.pending, pendingAccess{ - AccessType: access, + // AFAICT a buffer isn't strictly necessary here because if a writer sends to + // the channel, this goroutine will already have unlocked state and will eventually + // read from the channel, unblocking the lock holding goroutine. But let's be extra safe + wait := make(chan struct{}, 2) + txs.Pending = append(txs.Pending, access{ + AccessType: accKind, WaitChan: wait, - ID: id, + ID: accessID, }) updateTxs(txs) st.Unlock() - defer func() { - st.Lock() - txs, updateTxs, defErr := getOngoingTxs(st, account, schema) - if defErr != nil { - if err == nil { - err = fmt.Errorf("cannot access %s: cannot check ongoing transactions: %v", view.ID(), defErr) - } - return - } - - accIndex := -1 - for i, acc := range txs.pending { - if acc.ID == id { - accIndex = i - } - } - - if accIndex == -1 { - logger.Noticef("cannot find access id %s when updating pending accesses", id) - } else { - txs.pending = append(txs.pending[:accIndex], txs.pending[accIndex+1:]...) - } - - updateTxs(txs) - }() - - _, set := ctx.Deadline() - if !set { + if _, set := ctx.Deadline(); !set { // set a maximum waiting time to safeguard against this hanging forever var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, defaultWaitTimeout) defer cancel() } - if blockingSignalChan != nil { - // signal we're about to block for testing - close(blockingSignalChan) + if blockingSignals["wait-for-access"] != nil { + // for testing purposes only + close(blockingSignals["wait-for-access"]) } select { case <-wait: + st.Lock() case <-ctx.Done(): - return fmt.Errorf("cannot %s %s: timed out waiting for access", access, view.ID()) + // if the waiting was cancelled or timed out, clean up the pending state + st.Lock() + txs, updateTxs, err := getOngoingTxs(st, account, schema) + if err != nil { + return "", fmt.Errorf("cannot cleanup state after timeout/cancel: %v", err) + } + + for i, acc := range txs.Pending { + if acc.ID == accessID { + txs.Pending = append(txs.Pending[:i], txs.Pending[i+1:]...) + break + } + } + + // if the timeout/cancel raced with an unblock, the access might be in + // Scheduling so remove that + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + + err = maybeUnblockAccesses(txs) + if err != nil { + return "", fmt.Errorf("cannot cleanup state after timeout/cancel: %v", err) + } + + updateTxs(txs) + + return "", fmt.Errorf("cannot %s %s: timed out waiting for access", accKind, view.ID()) } - return nil + return accessID, nil } // WriteConfdb takes a map of request paths to values, schedules a change to // set the values in specified confdb view and run the appropriate hooks. // Returns a change ID. func WriteConfdb(ctx context.Context, st *state.State, view *confdb.View, values map[string]any) (changeID string, err error) { - defer func() { - if err != nil { - uerr := unblockNextAccess(st, view.Schema().Account, view.Schema().Name) - if uerr != nil { - logger.Noticef("cannot unblock next access after failed write: %v", uerr) - } - } - }() - - err = waitForAccess(ctx, st, view, writeAccess) + accessID, err := waitForAccess(ctx, st, view, writeAccess) if err != nil { return "", err } - account, schemaName := view.Schema().Account, view.Schema().Name + + account, schema := view.Schema().Account, view.Schema().Name + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) + } // not running in an existing confdb hook context, so create a transaction // and a change to verify its changes and commit - tx, err := NewTransaction(st, account, schemaName) + tx, err := NewTransaction(st, account, schema) if err != nil { return "", fmt.Errorf("cannot modify confdb through view %s: cannot create transaction: %v", view.ID(), err) } - err = SetViaView(tx, view, values) + err = setViaView(tx, view, values) if err != nil { return "", err } // the hooks we schedule depend on the paths written so this must happen after writing - ts, err := createChangeConfdbTasks(st, tx, view, "") + ts, commitTask, _, err := createChangeConfdbTasks(st, tx, view, "") if err != nil { return "", err } - chg := st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) - chg.AddAll(ts) - - commitTask, err := ts.Edge(commitEdge) + err = setWriteTransaction(st, account, schema, commitTask.ID(), accessID) if err != nil { return "", err } - err = setWriteTransaction(st, account, schemaName, commitTask.ID()) - if err != nil { - return "", err - } + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (txs.WriteTxID != "") + chg := st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) + chg.AddAll(ts) - return chg.ID(), err + return chg.ID(), nil } -type CommitTxFunc func() (changeID string, waitChan <-chan struct{}, err error) - -// GetTransactionToSet gets a transaction to change the confdb through the view. -// The state must be locked by the caller. Returns a transaction through which -// the confdb can be modified and a CommitTxFunc. The latter is called once the -// modifications are made to commit them. It will return a changeID and a channel, -// allowing the caller to block until commit. If a transaction was already ongoing, -// CommitTxFunc simply returns that without blocking (changes to it will be -// saved on ctx.Done()). -func GetTransactionToSet(hookCtx *hookstate.Context, st *state.State, view *confdb.View) (*Transaction, CommitTxFunc, error) { - account, schemaName := view.Schema().Account, view.Schema().Name +// WriteConfdbFromSnap takes a hook context and a map of requests to values that +// are written through the provided view. It will block until the writing change +// completes. +func WriteConfdbFromSnap(hookCtx *hookstate.Context, view *confdb.View, values map[string]any) (err error) { + account, schema := view.Schema().Account, view.Schema().Name // check if we're already running in the context of a committing transaction if IsConfdbHookCtx(hookCtx) { @@ -345,11 +352,11 @@ func GetTransactionToSet(hookCtx *hookstate.Context, st *state.State, view *conf t, _ := hookCtx.Task() tx, _, saveTxChanges, err := GetStoredTransaction(t) if err != nil { - return nil, nil, fmt.Errorf("cannot access confdb through view %s: cannot get transaction: %v", view.ID(), err) + return fmt.Errorf("cannot access confdb through view %s: cannot get transaction: %v", view.ID(), err) } - if tx.ConfdbAccount != account || tx.ConfdbName != schemaName { - return nil, nil, fmt.Errorf("cannot access confdb through view %s: ongoing transaction for %s/%s", view.ID(), tx.ConfdbAccount, tx.ConfdbName) + if tx.ConfdbAccount != account || tx.ConfdbName != schema { + return fmt.Errorf("cannot access confdb through view %s: ongoing transaction for %s/%s", view.ID(), tx.ConfdbAccount, tx.ConfdbName) } // update the commit task to save transaction changes made by the hook @@ -358,109 +365,109 @@ func GetTransactionToSet(hookCtx *hookstate.Context, st *state.State, view *conf return nil }) - return tx, nil, nil + return setViaView(tx, view, values) } - txs, _, err := getOngoingTxs(st, account, schemaName) + // get --wait-for timeout from context state, if any is set + ctx := context.Background() + if hookCtx.Timeout() != time.Duration(0) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, hookCtx.Timeout()) + defer cancel() + } + + st := hookCtx.State() + accessID, err := waitForAccess(ctx, st, view, writeAccess) if err != nil { - return nil, nil, fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) + return err } - if txs != nil && !txs.CanStartWriteTx() { - // TODO: eventually we want to queue this write and block until we serve it. - // It might also be necessary to have some form of timeout. - return nil, nil, fmt.Errorf("cannot write confdb through view %s: ongoing transaction", view.ID()) + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) } // not running in an existing confdb hook context, so create a transaction // and a change to verify its changes and commit - tx, err := NewTransaction(st, account, schemaName) + tx, err := NewTransaction(st, account, schema) if err != nil { - return nil, nil, fmt.Errorf("cannot modify confdb through view %s: cannot create transaction: %v", view.ID(), err) + return fmt.Errorf("cannot modify confdb through view %s: cannot create transaction: %v", view.ID(), err) } - commitTx := func() (string, <-chan struct{}, error) { - var chg *state.Change - if hookCtx == nil || hookCtx.IsEphemeral() { - chg = st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) - } else { - // we're running in the context of a non-confdb hook, add the tasks to that change - task, _ := hookCtx.Task() - chg = task.Change() - } + err = setViaView(tx, view, values) + if err != nil { + return err + } - var callingSnap string - if hookCtx != nil { - callingSnap = hookCtx.InstanceName() - } + var chg *state.Change + if hookCtx.IsEphemeral() { + chg = st.NewChange(setConfdbChangeKind, fmt.Sprintf("Set confdb through %q", view.ID())) + } else { + // we're running in the context of a non-confdb hook, add the tasks to that change + task, _ := hookCtx.Task() + chg = task.Change() + } - ts, err := createChangeConfdbTasks(st, tx, view, callingSnap) - if err != nil { - return "", nil, err - } - chg.AddAll(ts) + ts, commitTask, clearTxTask, err := createChangeConfdbTasks(st, tx, view, hookCtx.InstanceName()) + if err != nil { + return err + } - commitTask, err := ts.Edge(commitEdge) - if err != nil { - return "", nil, err - } + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (txs.WriteTxID != "") + err = setWriteTransaction(st, account, schema, commitTask.ID(), accessID) + if err != nil { + return err + } + chg.AddAll(ts) - clearTxTask, err := ts.Edge(clearTxEdge) - if err != nil { - return "", nil, err + waitChan := make(chan struct{}) + st.AddTaskStatusChangedHandler(func(t *state.Task, _, new state.Status) (remove bool) { + if t.ID() == clearTxTask.ID() && new.Ready() { + close(waitChan) + return true } + return false + }) - err = setWriteTransaction(st, account, schemaName, commitTask.ID()) - if err != nil { - return "", nil, err - } + ensureNow(st) - waitChan := make(chan struct{}) - st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) (remove bool) { - if t.ID() == clearTxTask.ID() && new.Ready() { - close(waitChan) - return true - } - return false - }) + // wait for the transaction to be committed + hookCtx.Unlock() + defer hookCtx.Lock() - ensureNow(st) - return chg.ID(), waitChan, nil + if blockingSignals["wait-for-change-done"] != nil { + // for testing purposes only + close(blockingSignals["wait-for-change-done"]) } - return tx, commitTx, nil -} - -var ( - ensureNow = func(st *state.State) { - st.EnsureBefore(0) + select { + case <-waitChan: + case <-time.After(transactionTimeout): + return fmt.Errorf("cannot set confdb %s: timed out after %s", view.ID(), transactionTimeout) } - transactionTimeout = 2 * time.Minute -) - -const ( - commitEdge = state.TaskSetEdge("commit-edge") - clearTxEdge = state.TaskSetEdge("clear-tx-edge") -) + return nil +} -func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, callingSnap string) (*state.TaskSet, error) { +func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, callingSnap string) (ts *state.TaskSet, commitTask, clearTxTask *state.Task, err error) { custodians, custodianPlugs, err := getCustodianPlugsForView(st, view) if err != nil { - return nil, err + return nil, nil, nil, err } if len(custodianPlugs) == 0 { - return nil, fmt.Errorf("cannot commit changes to confdb made through view %s: no custodian snap connected", view.ID()) + return nil, nil, nil, fmt.Errorf("cannot commit changes to confdb made through view %s: no custodian snap connected", view.ID()) } paths := tx.AlteredPaths() mightAffectEph, err := view.WriteAffectsEphemeral(paths) if err != nil { - return nil, err + return nil, nil, nil, err } - ts := state.NewTaskSet() + ts = state.NewTaskSet() linkTask := func(t *state.Task) { tasks := ts.Tasks() if len(tasks) > 0 { @@ -492,7 +499,7 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View } if hookPrefix == "save-view-" && mightAffectEph && !saveViewHookPresent { - return nil, fmt.Errorf("cannot access %s: write might change ephemeral data but no custodians has a save-view hook", view.ID()) + return nil, nil, nil, fmt.Errorf("cannot access %s: write might change ephemeral data but no custodians has a save-view hook", view.ID()) } } @@ -500,7 +507,7 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View // changed with this data modification affectedPlugs, err := getPlugsAffectedByPaths(st, view.Schema(), paths) if err != nil { - return nil, err + return nil, nil, nil, err } viewChangedSnaps := make([]string, 0, len(affectedPlugs)) @@ -523,22 +530,22 @@ func createChangeConfdbTasks(st *state.State, tx *Transaction, view *confdb.View } // commit after custodians save ephemeral data - commitTask := st.NewTask("commit-confdb-tx", fmt.Sprintf("Commit changes to confdb (%s)", view.ID())) + commitTask = st.NewTask("commit-confdb-tx", fmt.Sprintf("Commit changes to confdb (%s)", view.ID())) commitTask.Set("confdb-transaction", tx) + commitTask.Set("view", view.Name) + // link all previous tasks to the commit task that carries the transaction for _, t := range ts.Tasks() { t.Set("tx-task", commitTask.ID()) } linkTask(commitTask) - ts.MarkEdge(commitTask, commitEdge) // clear the ongoing tx from the state and unblock other writers waiting for it - clearTxTask := st.NewTask("clear-confdb-tx", "Clears the ongoing confdb transaction from state") + clearTxTask = st.NewTask("clear-confdb-tx", "Clears the ongoing confdb transaction from state") linkTask(clearTxTask) clearTxTask.Set("tx-task", commitTask.ID()) - ts.MarkEdge(clearTxTask, clearTxEdge) - return ts, nil + return ts, commitTask, clearTxTask, nil } // getCustodianPlugsForView returns a list of snaps that have connected plugs @@ -663,7 +670,7 @@ func GetStoredTransaction(t *state.Task) (tx *Transaction, txTask *state.Task, s // IsConfdbHookCtx returns whether the hook context belongs to a confdb hook. func IsConfdbHookCtx(ctx *hookstate.Context) bool { - return ctx != nil && !ctx.IsEphemeral() && IsConfdbHookname(ctx.HookName()) + return !ctx.IsEphemeral() && IsConfdbHookname(ctx.HookName()) } // IsConfdbHookname returns whether the hookname denotes a confdb hook. @@ -676,21 +683,22 @@ func IsConfdbHookname(name string) bool { } // CanHookSetConfdb returns whether the hook context belongs to a confdb hook -// that supports snapctl set (either a write hook or load-view). +// that supports snapctl set (either a write hook or load-view). Returns false +// if the context is ephemeral. func CanHookSetConfdb(ctx *hookstate.Context) bool { - return ctx != nil && !ctx.IsEphemeral() && + return !ctx.IsEphemeral() && (strings.HasPrefix(ctx.HookName(), "change-view-") || strings.HasPrefix(ctx.HookName(), "query-view-") || strings.HasPrefix(ctx.HookName(), "load-view-")) } -// GetTransactionForSnapctlGet gets a transaction to read the view's confdb. It -// schedules tasks to load the confdb as needed, unless no custodian defined -// relevant hooks. Blocks until the confdb has been loaded into the Transaction. -// If no tasks need to run to load the confdb, returns without blocking. -func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, paths []string, constraints map[string]any) (*Transaction, error) { +// ReadConfdbFromSnap gets a transaction to read the view's confdb. It schedules +// tasks to load the confdb as needed, unless no custodian defined relevant +// hooks. Blocks until the confdb has been loaded into the Transaction. If no +// tasks need to run to load the confdb, returns without blocking. +func ReadConfdbFromSnap(hookCtx *hookstate.Context, view *confdb.View, paths []string, constraints map[string]any) (tx *Transaction, err error) { st := hookCtx.State() - account, schemaName := view.Schema().Account, view.Schema().Name + account, schema := view.Schema().Account, view.Schema().Name if IsConfdbHookCtx(hookCtx) { // running in the context of a transaction, so if the referenced confdb @@ -701,37 +709,41 @@ func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, return nil, fmt.Errorf("cannot load confdb view %s: cannot get transaction: %v", view.ID(), err) } - if tx.ConfdbAccount != account || tx.ConfdbName != schemaName { + if tx.ConfdbAccount != account || tx.ConfdbName != schema { // TODO: this should be enabled at some point - return nil, fmt.Errorf("cannot load confdb %s/%s: ongoing transaction for %s/%s", account, schemaName, tx.ConfdbAccount, tx.ConfdbName) + return nil, fmt.Errorf("cannot load confdb %s/%s: ongoing transaction for %s/%s", account, schema, tx.ConfdbAccount, tx.ConfdbName) } // we're reading the tx that this hook is modifying, just return that return tx, nil } - // TODO: replace this with the concurrent access logic. Derive timeout from hookstate.Context - // if not otherwise set? - txs, _, err := getOngoingTxs(st, account, schemaName) + ctx := context.Background() + if hookCtx.Timeout() != time.Duration(0) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, hookCtx.Timeout()) + defer cancel() + } + + accessID, err := waitForAccess(ctx, st, view, readAccess) if err != nil { - return nil, fmt.Errorf("cannot access confdb view %s: cannot check ongoing transactions: %v", view.ID(), err) + return nil, err } - // TODO: use txs.CanStartReadTx() once we support blocking access here - if txs.WriteTxID != "" || len(txs.pending) > 0 { - // TODO: eventually we want to queue this load and block until we serve it. - // It might also be necessary to have some form of timeout. - return nil, fmt.Errorf("cannot access confdb view %s: ongoing write transaction", view.ID()) + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) } // not running in an existing confdb hook context, so create a transaction // and a change to load/modify data - tx, err := NewTransaction(st, account, schemaName) + tx, err = NewTransaction(st, account, schema) if err != nil { return nil, fmt.Errorf("cannot load confdb view %s: cannot create transaction: %v", view.ID(), err) } - ts, err := createLoadConfdbTasks(st, tx, view, paths, constraints) + ts, clearTxTask, err := createLoadConfdbTasks(st, tx, view, paths, constraints) if err != nil { return nil, err } @@ -750,13 +762,6 @@ func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, chg = task.Change() } - chg.AddAll(ts) - - clearTxTask, err := ts.Edge(clearTxEdge) - if err != nil { - return nil, err - } - waitChan := make(chan struct{}) st.AddTaskStatusChangedHandler(func(t *state.Task, old, new state.Status) (remove bool) { if t.ID() == clearTxTask.ID() && new.Ready() { @@ -766,19 +771,27 @@ func GetTransactionForSnapctlGet(hookCtx *hookstate.Context, view *confdb.View, return false }) - err = addReadTransaction(st, account, schemaName, clearTxTask.ID()) + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (len(txs.ReadTxIDs) > 0) + err = addReadTransaction(st, account, schema, clearTxTask.ID(), accessID) if err != nil { return nil, err } + chg.AddAll(ts) ensureNow(st) hookCtx.Unlock() + if blockingSignals["wait-for-change-done"] != nil { + // for testing purposes only + close(blockingSignals["wait-for-change-done"]) + } + select { case <-waitChan: case <-time.After(transactionTimeout): hookCtx.Lock() - return nil, fmt.Errorf("cannot load confdb %s/%s in change %s: timed out after %s", account, schemaName, chg.ID(), transactionTimeout) + return nil, fmt.Errorf("cannot load confdb %s/%s in change %s: timed out after %s", account, schema, chg.ID(), transactionTimeout) } hookCtx.Lock() @@ -795,53 +808,73 @@ const ( writeAccess accessType = "write" ) -type pendingAccess struct { +// access holds data for a pending access, namely a unique identifier, +// access type (read or write) and a channel use to signal that the access can +// proceed. +type access struct { // ID is a random string identifying this access. ID string - // AccessType denotes whether the access is read or write. Exported for - // testing purposes. + // AccessType denotes whether the access is read or write. AccessType accessType // WaitChan is closed to unblock the pending access. WaitChan chan<- struct{} } +// cleanupAccess removes state related to processing an access, if any exists +// (i.e., if the access had to wait and was eventually unblocked). If no tasks +// were scheduled and there aren't other accesses waiting to schedule, it unblocks +// the next pending accesses. +func cleanupAccess(st *state.State, accessID, account, schema string) { + txs, updateTxStateFunc, uerr := getOngoingTxs(st, account, schema) + if uerr != nil { + logger.Noticef("cannot unblock next access after failed access: %v", uerr) + return + } + defer updateTxStateFunc(txs) + + // remove this access from the scheduling list, if we haven't yet + for i, acc := range txs.Scheduling { + if acc.ID == accessID { + txs.Scheduling = append(txs.Scheduling[:i], txs.Scheduling[i+1:]...) + break + } + } + + // this may actually not unblock anything, if other accesses are being processed + uerr = maybeUnblockAccesses(txs) + if uerr != nil { + logger.Noticef("cannot unblock next access after failed access: %v", uerr) + } +} + // ReadConfdb schedules a change to load a confdb, running any appropriate // hooks and fulfilling the requests by reading the view and placing the // resulting data in the change's data (so it can be read by the client). func ReadConfdb(ctx context.Context, st *state.State, view *confdb.View, requests []string, constraints map[string]any, userAccess confdb.Access) (changeID string, err error) { - defer func() { - if err != nil { - uerr := unblockNextAccess(st, view.Schema().Account, view.Schema().Name) - if uerr != nil { - logger.Noticef("cannot unblock next access after failed read: %v", uerr) - } - } - }() - - err = waitForAccess(ctx, st, view, readAccess) + accessID, err := waitForAccess(ctx, st, view, readAccess) if err != nil { return "", err } account, schema := view.Schema().Account, view.Schema().Name + // accessID is empty if we didn't release the lock and wait, so no state was + // modified and there aren't other accesses to unblock + if accessID != "" { + defer cleanupAccess(st, accessID, account, schema) + } + tx, err := NewTransaction(st, account, schema) if err != nil { return "", fmt.Errorf("cannot access confdb view %s: cannot create transaction: %v", view.ID(), err) } - ts, err := createLoadConfdbTasks(st, tx, view, requests, constraints) + ts, clearTxTask, err := createLoadConfdbTasks(st, tx, view, requests, constraints) if err != nil { return "", err } chg := st.NewChange(getConfdbChangeKind, fmt.Sprintf(`Get confdb through %q`, view.ID())) if ts != nil { - // if there are hooks to run, link the read-confdb task to those tasks - clearTxTask, err := ts.Edge(clearTxEdge) - if err != nil { - return "", err - } - // schedule a task to read the tx after the hook and add the data to the // change so it can be read by the client loadConfdbTask := st.NewTask("load-confdb-change", "Load confdb data into the change") @@ -854,7 +887,9 @@ func ReadConfdb(ctx context.Context, st *state.State, view *confdb.View, request loadConfdbTask.WaitFor(clearTxTask) chg.AddAll(ts) - err = addReadTransaction(st, account, schema, clearTxTask.ID()) + // schedule tasks after saving the tx ID so the deferred cleanup skips waking + // up waiters if a task will do it (len(txs.ReadTxIDs) > 0) + err = addReadTransaction(st, account, schema, clearTxTask.ID(), accessID) if err != nil { return "", err } @@ -876,14 +911,14 @@ func ReadConfdb(ctx context.Context, st *state.State, view *confdb.View, request // read a transaction through the given view. In case no custodian snap has any // load-view or query-view hooks, nil is returned. If there are hooks to run, // a clear-confdb-tx task is also scheduled to remove the ongoing transaction at the end. -func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, requests []string, constraints map[string]any) (*state.TaskSet, error) { +func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, requests []string, constraints map[string]any) (*state.TaskSet, *state.Task, error) { custodians, custodianPlugs, err := getCustodianPlugsForView(st, view) if err != nil { - return nil, err + return nil, nil, err } if len(custodians) == 0 { - return nil, fmt.Errorf("cannot load confdb through view %s: no custodian snap connected", view.ID()) + return nil, nil, fmt.Errorf("cannot load confdb through view %s: no custodian snap connected", view.ID()) } ts := state.NewTaskSet() @@ -897,7 +932,7 @@ func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, mightAffectEph, err := view.ReadAffectsEphemeral(requests, constraints) if err != nil { - return nil, err + return nil, nil, err } hookPrefixes := []string{"load-view-", "query-view-"} @@ -921,14 +956,14 @@ func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, // there must be least one load-view hook if we're accessing ephemeral data if hookPrefix == "load-view-" && mightAffectEph && !loadViewHookPresent { - return nil, fmt.Errorf("cannot schedule tasks to access %s: read might cover ephemeral data but no custodian has a load-view hook", view.ID()) + return nil, nil, fmt.Errorf("cannot schedule tasks to access %s: read might cover ephemeral data but no custodian has a load-view hook", view.ID()) } } if len(hooks) == 0 { // no hooks to run and not running from API (don't need task to populate) // data in change so we can just read the databag synchronously - return nil, nil + return nil, nil, nil } // clear the tx from the state if the change fails @@ -945,11 +980,9 @@ func createLoadConfdbTasks(st *state.State, tx *Transaction, view *confdb.View, for _, t := range ts.Tasks() { t.Set("tx-task", clearTxTask.ID()) } - linkTask(clearTxTask) - ts.MarkEdge(clearTxTask, clearTxEdge) - return ts, nil + return ts, clearTxTask, nil } func MockFetchConfdbSchemaAssertion(f func(*state.State, int, string, string) error) func() { diff --git a/overlord/confdbstate/confdbstate_test.go b/overlord/confdbstate/confdbstate_test.go index 5ae52a227f1..5690bfc2533 100644 --- a/overlord/confdbstate/confdbstate_test.go +++ b/overlord/confdbstate/confdbstate_test.go @@ -43,7 +43,6 @@ import ( "github.com/snapcore/snapd/overlord/confdbstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" - "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/ifacestate/ifacerepo" "github.com/snapcore/snapd/overlord/snapstate" @@ -193,7 +192,7 @@ func (s *confdbTestSuite) SetUpTest(c *C) { c.Assert(err, IsNil) tr.Commit() - confdbstate.SetBlockingSignalChan(nil) + confdbstate.ResetBlockingSignals() } func parsePath(c *C, path string) []confdb.Accessor { @@ -365,7 +364,7 @@ func (s *confdbTestSuite) TestUnsetView(c *C) { c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) } -func (s *confdbTestSuite) TestConfdbstateGetEntireView(c *C) { +func (s *confdbTestSuite) TestGetEntireView(c *C) { s.state.Lock() defer s.state.Unlock() @@ -613,18 +612,12 @@ func (s *confdbTestSuite) TestConfdbTasksUserSetWithCustodianInstalled(c *C) { chg := s.state.NewChange("modify-confdb", "") // a user (not a snap) changes a confdb - ts, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "") - c.Assert(err, IsNil) - chg.AddAll(ts) - - // there are two edges in the taskset - commitTask, err := ts.Edge(confdbstate.CommitEdge) + ts, commitTask, clearTask, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "") c.Assert(err, IsNil) c.Assert(commitTask.Kind(), Equals, "commit-confdb-tx") + c.Assert(clearTask.Kind(), Equals, "clear-confdb-tx") - cleanupTask, err := ts.Edge(confdbstate.ClearTxEdge) - c.Assert(err, IsNil) - c.Assert(cleanupTask.Kind(), Equals, "clear-confdb-tx") + chg.AddAll(ts) // the custodian snap's hooks are run tasks := []string{"clear-confdb-tx-on-error", "run-hook", "run-hook", "run-hook", "commit-confdb-tx", "clear-confdb-tx"} @@ -670,7 +663,7 @@ func (s *confdbTestSuite) TestConfdbTasksCustodianSnapSet(c *C) { chg := s.state.NewChange("set-confdb", "") // a user (not a snap) changes a confdb - ts, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "custodian-snap") + ts, _, _, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "custodian-snap") c.Assert(err, IsNil) chg.AddAll(ts) @@ -713,7 +706,7 @@ func (s *confdbTestSuite) TestConfdbTasksObserverSnapSetWithCustodianInstalled(c chg := s.state.NewChange("modify-confdb", "") // a non-custodian snap modifies a confdb - ts, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") + ts, _, _, err := confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") c.Assert(err, IsNil) chg.AddAll(ts) @@ -784,7 +777,7 @@ func (s *confdbTestSuite) testConfdbTasksNoCustodian(c *C) { view := s.dbSchema.View("setup-wifi") // a non-custodian snap modifies a confdb - _, err = confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") + _, _, _, err = confdbstate.CreateChangeConfdbTasks(s.state, tx, view, "test-snap-1") c.Assert(err, ErrorMatches, fmt.Sprintf("cannot commit changes to confdb made through view %s/network/%s: no custodian snap connected", s.devAccID, view.Name)) } @@ -1019,7 +1012,7 @@ func (s *confdbTestSuite) checkOngoingWriteConfdbTx(c *C, account, confdbName st c.Assert(commitTask.Status(), Equals, state.DoStatus) } -func (s *confdbTestSuite) TestGetTransactionFromUserCreatesNewChange(c *C) { +func (s *confdbTestSuite) TestWriteConfdbCreatesNewChange(c *C) { hooks, restore := s.mockConfdbHooks() defer restore() @@ -1038,37 +1031,24 @@ func (s *confdbTestSuite) TestGetTransactionFromUserCreatesNewChange(c *C) { s.setupConfdbScenario(c, custodians, nil) view := s.dbSchema.View("setup-wifi") - - tx, commitTxFunc, err := confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, IsNil) - c.Assert(tx, NotNil) - c.Assert(commitTxFunc, NotNil) - - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - - // mock the daemon triggering the commit - changeID, waitChan, err := commitTxFunc() + chgID, err := confdbstate.WriteConfdb(context.Background(), s.state, view, map[string]any{ + "ssid": "foo", + }) c.Assert(err, IsNil) - s.state.Unlock() - select { - case <-waitChan: - case <-time.After(testutil.HostScaledTimeout(5 * time.Second)): - s.state.Lock() - c.Fatal("test timed out after 5s") - } - s.state.Lock() - c.Assert(s.state.Changes(), HasLen, 1) chg := s.state.Changes()[0] c.Assert(chg.Kind(), Equals, "set-confdb") - c.Assert(changeID, Equals, chg.ID()) + c.Assert(chg.ID(), Equals, chgID) + + s.state.Unlock() + s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + s.state.Lock() s.checkSetConfdbChange(c, chg, hooks) } -func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { +func (s *confdbTestSuite) TestWriteConfdbFromSnapCreatesNewChange(c *C) { hooks, restore := s.mockConfdbHooks() defer restore() @@ -1081,6 +1061,7 @@ func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { s.state.Lock() defer s.state.Unlock() + view := s.dbSchema.View("setup-wifi") // only one custodian snap is installed custodians := map[string]confdbHooks{"custodian-snap": allHooks} @@ -1089,15 +1070,13 @@ func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { ctx, err := hookstate.NewContext(nil, s.state, &hookstate.HookSetup{Snap: "test-snap"}, nil, "") c.Assert(err, IsNil) - s.state.Unlock() - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=foo"}, 0, nil) + + ctx.Lock() + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) // this is called automatically by hooks or manually for daemon/ - ctx.Lock() ctx.Done() ctx.Unlock() @@ -1110,20 +1089,24 @@ func (s *confdbTestSuite) TestGetTransactionFromSnapCreatesNewChange(c *C) { } func (s *confdbTestSuite) TestGetTransactionFromNonConfdbHookAddsConfdbTx(c *C) { + view := s.dbSchema.View("setup-wifi") + var hooks []string restore := hookstate.MockRunHook(func(ctx *hookstate.Context, _ *tomb.Tomb) ([]byte, error) { t, _ := ctx.Task() - ctx.State().Lock() + s.state.Lock() var hooksup *hookstate.HookSetup err := t.Get("hook-setup", &hooksup) - ctx.State().Unlock() + s.state.Unlock() if err != nil { return nil, err } if hooksup.Hook == "install" { - _, _, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=foo"}, 0, nil) + ctx.Lock() + err := confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) + ctx.Unlock() c.Assert(err, IsNil) return nil, nil } @@ -1220,56 +1203,8 @@ func (s *confdbTestSuite) checkSetConfdbChange(c *C, chg *state.Change, hooks *[ c.Assert(val, Equals, "foo") } -func (s *confdbTestSuite) TestGetTransactionFromChangeViewHook(c *C) { - ctx := s.testGetReadableOngoingTransaction(c, "change-view-setup") - - // change-view hooks can also write to the transaction - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0, nil) - c.Assert(err, IsNil) - // accessed an ongoing transaction - c.Assert(stdout, IsNil) - c.Assert(stderr, IsNil) - - // this save the changes that the hook performs - ctx.Lock() - ctx.Done() - ctx.Unlock() - - s.state.Lock() - defer s.state.Unlock() - t, _ := ctx.Task() - tx, _, _, err := confdbstate.GetStoredTransaction(t) - c.Assert(err, IsNil) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "bar") -} - -func (s *confdbTestSuite) TestGetTransactionFromSaveViewHook(c *C) { - ctx := s.testGetReadableOngoingTransaction(c, "save-view-setup") - - // non change-view hooks cannot modify the transaction - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0, nil) - c.Assert(err, ErrorMatches, `cannot modify confdb in "save-view-setup" hook`) - c.Assert(stdout, IsNil) - c.Assert(stderr, IsNil) -} - -func (s *confdbTestSuite) TestGetTransactionFromViewChangedHook(c *C) { - ctx := s.testGetReadableOngoingTransaction(c, "observe-view-setup") - - // non change-view hooks cannot modify the transaction - stdout, stderr, err := ctlcmd.Run(ctx, []string{"set", "--view", ":setup", "ssid=bar"}, 0, nil) - c.Assert(err, ErrorMatches, `cannot modify confdb in "observe-view-setup" hook`) - c.Assert(stdout, IsNil) - c.Assert(stderr, IsNil) -} - -func (s *confdbTestSuite) testGetReadableOngoingTransaction(c *C, hook string) *hookstate.Context { +func (s *confdbTestSuite) TestWriteConfdbFromChangeViewHook(c *C) { s.state.Lock() - defer s.state.Unlock() - custodians := map[string]confdbHooks{"custodian-snap": allHooks} s.setupConfdbScenario(c, custodians, []string{"test-snap"}) @@ -1286,22 +1221,42 @@ func (s *confdbTestSuite) testGetReadableOngoingTransaction(c *C, hook string) * hookTask := s.state.NewTask("run-hook", "") chg.AddTask(hookTask) - setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: hook} + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "change-view-setup"} mockHandler := hooktest.NewMockHandler() hookTask.Set("tx-task", commitTask.ID()) + s.state.Unlock() ctx, err := hookstate.NewContext(hookTask, s.state, setup, mockHandler, "") c.Assert(err, IsNil) - s.state.Unlock() - stdout, stderr, err := ctlcmd.Run(ctx, []string{"get", "--view", ":setup", "ssid"}, 0, nil) - s.state.Lock() + ctx.Lock() + view := s.dbSchema.View("setup-wifi") + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) // accessed an ongoing transaction - c.Assert(string(stdout), Equals, "foo\n") - c.Assert(stderr, IsNil) + data, err := tx.Get(parsePath(c, "wifi.ssid"), nil) + c.Assert(err, IsNil) + c.Assert(data, Equals, "foo") + + // change-view hooks can also write to the transaction + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{ + "ssid": "bar", + }) + c.Assert(err, IsNil) + + // accessed an ongoing transaction so save the changes made by the hook + ctx.Done() + ctx.Unlock() + + s.state.Lock() + defer s.state.Unlock() + t, _ := ctx.Task() + tx, _, _, err = confdbstate.GetStoredTransaction(t) + c.Assert(err, IsNil) - return ctx + val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) + c.Assert(err, IsNil) + c.Assert(val, Equals, "bar") } func (s *confdbTestSuite) TestGetDifferentTransactionThanOngoing(c *C) { @@ -1334,11 +1289,10 @@ func (s *confdbTestSuite) TestGetDifferentTransactionThanOngoing(c *C) { c.Assert(err, IsNil) ctx.Lock() - tx, commitTxFunc, err := confdbstate.GetTransactionToSet(ctx, s.state, confdb.View("foo")) + view := confdb.View("foo") + err = confdbstate.WriteConfdbFromSnap(ctx, view, nil) ctx.Unlock() c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot access confdb through view foo/bar/foo: ongoing transaction for %s/network`, s.devAccID)) - c.Assert(tx, IsNil) - c.Assert(commitTxFunc, IsNil) } func (s *confdbTestSuite) TestConfdbLoadDisconnectedCustodianSnap(c *C) { @@ -1373,7 +1327,7 @@ func (s *confdbTestSuite) testConfdbLoadNoCustodian(c *C) { view := s.dbSchema.View("setup-wifi") // a non-custodian snap modifies a confdb - _, err = confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + _, _, err = confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot load confdb through view %s/network/setup-wifi: no custodian snap connected", s.devAccID)) } @@ -1434,12 +1388,9 @@ func (s *confdbTestSuite) TestConfdbLoadCustodianInstalled(c *C) { view := s.dbSchema.View("setup-wifi") chg := s.state.NewChange("load-confdb", "") - ts, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + ts, cleanupTask, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) chg.AddAll(ts) - - cleanupTask, err := ts.Edge(confdbstate.ClearTxEdge) - c.Assert(err, IsNil) c.Assert(cleanupTask.Kind(), Equals, "clear-confdb-tx") // the custodian snap's hooks are run @@ -1477,7 +1428,7 @@ func (s *confdbTestSuite) TestConfdbLoadCustodianWithNoHooks(c *C) { c.Assert(err, IsNil) view := s.dbSchema.View("setup-wifi") - ts, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + ts, _, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) // no hooks, nothing to run c.Assert(ts, IsNil) @@ -1498,7 +1449,7 @@ func (s *confdbTestSuite) TestConfdbLoadTasks(c *C) { c.Assert(err, IsNil) view := s.dbSchema.View("setup-wifi") - ts, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) + ts, _, err := confdbstate.CreateLoadConfdbTasks(s.state, tx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) chg := s.state.NewChange("get-confdb", "") chg.AddAll(ts) @@ -1521,18 +1472,19 @@ func (s *confdbTestSuite) TestConfdbLoadTasks(c *C) { checkLoadConfdbTasks(c, chg, tasks, hooks) } -func (s *confdbTestSuite) TestGetTransactionForSnapctlNoHook(c *C) { +func (s *confdbTestSuite) TestReadConfdbFromSnapEphemeral(c *C) { s.state.Lock() // only one custodian snap is installed custodians := map[string]confdbHooks{"custodian-snap": allHooks} s.setupConfdbScenario(c, custodians, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) s.state.Unlock() - chg := s.testGetTransactionForSnapctl(c, ctx) + chg := s.testReadConfdbFromSnap(c, ctx) s.state.Lock() defer s.state.Unlock() @@ -1563,18 +1515,16 @@ func (s *confdbTestSuite) TestGetTransactionForSnapctlNonConfdbHook(c *C) { c.Assert(err, IsNil) s.state.Unlock() - s.testGetTransactionForSnapctl(c, ctx) + s.testReadConfdbFromSnap(c, ctx) } -func (s *confdbTestSuite) testGetTransactionForSnapctl(c *C, ctx *hookstate.Context) *state.Change { +func (s *confdbTestSuite) testReadConfdbFromSnap(c *C, ctx *hookstate.Context) *state.Change { hooks, restore := s.mockConfdbHooks() defer restore() restore = confdbstate.MockEnsureNow(func(*state.State) { s.checkOngoingReadConfdbTx(c, s.devAccID, "network") - go func() { - s.o.Settle(5 * time.Second) - }() + go s.o.Settle(5 * time.Second) }) defer restore() @@ -1589,7 +1539,7 @@ func (s *confdbTestSuite) testGetTransactionForSnapctl(c *C, ctx *hookstate.Cont s.state.Set("confdb-databags", map[string]map[string]confdb.JSONDatabag{s.devAccID: {"network": bag}}) view := s.dbSchema.View("setup-wifi") - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) c.Assert(s.state.Changes(), HasLen, 1) @@ -1630,7 +1580,7 @@ func (s *confdbTestSuite) TestGetTransactionInConfdbHook(c *C) { c.Assert(err, IsNil) view := s.dbSchema.View("setup-wifi") - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, IsNil) // reads synchronously without creating new change or tasks c.Assert(s.state.Changes(), HasLen, 1) @@ -1656,19 +1606,24 @@ func (s *confdbTestSuite) TestGetTransactionNoConfdbHooks(c *C) { Hook: "install", } hookTask.Set("hook-setup", hooksup) - mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(hookTask, s.state, hooksup, mockHandler, "") - c.Assert(err, IsNil) // write some value for the get to read bag := confdb.NewJSONDatabag() - err = bag.Set(parsePath(c, "wifi.ssid"), "foo") + err := bag.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) s.state.Set("confdb-databags", map[string]map[string]confdb.JSONDatabag{s.devAccID: {"network": bag}}) + mockHandler := hooktest.NewMockHandler() + ctx, err := hookstate.NewContext(hookTask, s.state, hooksup, mockHandler, "") + c.Assert(err, IsNil) + view := s.dbSchema.View("setup-wifi") - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + s.state.Unlock() + ctx.Lock() + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) + ctx.Unlock() + s.state.Lock() c.Assert(err, IsNil) c.Assert(tx, NotNil) @@ -1694,7 +1649,8 @@ func (s *confdbTestSuite) TestGetTransactionTimesOut(c *C) { s.setupConfdbScenario(c, custodians, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) // write some value for the get to read @@ -1709,7 +1665,7 @@ func (s *confdbTestSuite) TestGetTransactionTimesOut(c *C) { ctx.Lock() defer ctx.Unlock() - tx, err := confdbstate.GetTransactionForSnapctlGet(ctx, view, nil, nil) + tx, err := confdbstate.ReadConfdbFromSnap(ctx, view, nil, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot load confdb %s/network in change 1: timed out after 0s", s.devAccID)) c.Assert(tx, IsNil) } @@ -1781,7 +1737,7 @@ func (s *confdbTestSuite) checkOngoingReadConfdbTx(c *C, account, confdbName str c.Assert(clearTask.Status(), Equals, state.DoStatus) } -func (s *confdbTestSuite) TestGetTransactionForAPI(c *C) { +func (s *confdbTestSuite) TestAPIReadConfdb(c *C) { s.state.Lock() custodians := map[string]confdbHooks{"custodian-snap": allHooks} nonCustodians := []string{"test-snap"} @@ -1837,7 +1793,7 @@ func (s *confdbTestSuite) TestGetTransactionForAPI(c *C) { }) } -func (s *confdbTestSuite) TestGetTransactionForAPINoHooks(c *C) { +func (s *confdbTestSuite) TestReadConfdbNoHooks(c *C) { s.state.Lock() defer s.state.Unlock() @@ -1881,7 +1837,76 @@ func (s *confdbTestSuite) TestGetTransactionForAPINoHooks(c *C) { }) } -func (s *confdbTestSuite) TestGetTransactionForAPINoHooksError(c *C) { +func (s *confdbTestSuite) TestReadConfdbNoHooksUnblocksNextPendingAccess(c *C) { + s.state.Lock() + + custodians := map[string]confdbHooks{"custodian-snap": noHooks} + nonCustodians := []string{"test-snap"} + s.setupConfdbScenario(c, custodians, nonCustodians) + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + + // testing helper closed when the access is about to block + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + var chgID string + doneChan := make(chan struct{}) + go func() { + var err error + chgID, err = confdbstate.ReadConfdb(context.Background(), s.state, view, []string{"ssid"}, nil, 0) + c.Assert(err, IsNil) + s.state.Unlock() + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the blocked read released the lock before waiting + s.state.Lock() + accs, ok := s.state.Cached("pending-confdb-" + ref).([]confdbstate.Access) + c.Assert(ok, Equals, true) + c.Assert(accs, HasLen, 1) + c.Assert(accs[0].AccessType, Equals, confdbstate.AccessType("read")) + + nextWaitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-write", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected read to complete but timed out") + } + + s.state.Lock() + defer s.state.Unlock() + + chg := s.state.Change(chgID) + c.Assert(chg, NotNil) + c.Assert(chg.Tasks(), HasLen, 0) + c.Assert(chg.Status(), Equals, state.DoneStatus) +} + +func (s *confdbTestSuite) TestAPIReadConfdbNoHooksError(c *C) { s.state.Lock() defer s.state.Unlock() @@ -1915,7 +1940,7 @@ func (s *confdbTestSuite) TestGetTransactionForAPINoHooksError(c *C) { c.Assert(errKind, Equals, "option-not-found") } -func (s *confdbTestSuite) TestGetTransactionForAPIError(c *C) { +func (s *confdbTestSuite) TestAPIReadConfdbError(c *C) { s.state.Lock() custodians := map[string]confdbHooks{"custodian-snap": allHooks} nonCustodians := []string{"test-snap"} @@ -1949,87 +1974,38 @@ func (s *confdbTestSuite) TestGetTransactionForAPIError(c *C) { c.Assert(errKind, Equals, "option-not-found") } -// TODO: replace these tests once the snapctl flow is also blocking -func (s *confdbTestSuite) TestConcurrentAccessWithOngoingWrite(c *C) { +func (s *confdbTestSuite) TestWriteAffectingEphemeralMustDefineSaveViewHook(c *C) { s.state.Lock() - defer s.state.Unlock() - - s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) - _, restore := s.mockConfdbHooks() - defer restore() - - err := confdbstate.SetWriteTransaction(s.state, s.devAccID, "network", "1") - c.Assert(err, IsNil) - - view := s.dbSchema.View("setup-wifi") - - // reading from the snap - mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") - c.Assert(err, IsNil) - - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, nil, nil) - c.Assert(err, ErrorMatches, fmt.Sprintf("cannot access confdb view %s/network/setup-wifi: ongoing write transaction", s.devAccID)) - - // writing (used both from snap or API) - _, _, err = confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, ErrorMatches, fmt.Sprintf("cannot write confdb through view %s/network/setup-wifi: ongoing transaction", s.devAccID)) -} + hooks := observeView | queryView | loadView | changeView + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": hooks}, nil) + s.state.Unlock() -func (s *confdbTestSuite) TestConcurrentAccessWithOngoingRead(c *C) { - s.state.Lock() - // it's better not to have hooks here because if we do the GetTransactionForSnapctlGet - // needs to schedule tasks and will block on them, making this test more timing based/annoying - s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": noHooks}, nil) + restore := confdbstate.MockEnsureNow(func(*state.State) { + s.checkOngoingWriteConfdbTx(c, s.devAccID, "network") - err := confdbstate.AddReadTransaction(s.state, s.devAccID, "network", "1") - c.Assert(err, IsNil) - s.state.Unlock() + go s.o.Settle(testutil.HostScaledTimeout(5 * time.Second)) + }) + defer restore() mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1)} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) ctx.Lock() defer ctx.Unlock() - - view := s.dbSchema.View("setup-wifi") - // writing (used both from snap or API) conflicts - _, _, err = confdbstate.GetTransactionToSet(ctx, s.state, view) - c.Assert(err, ErrorMatches, fmt.Sprintf("cannot write confdb through view %s/network/setup-wifi: ongoing transaction", s.devAccID)) - - // we can read from the API and the snap concurrently with other reads - _, err = confdbstate.ReadConfdb(context.Background(), s.state, view, []string{"ssid"}, nil, confdb.AdminAccess) - c.Assert(err, IsNil) - - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) - c.Assert(err, IsNil) -} - -func (s *confdbTestSuite) TestWriteAffectingEphemeralMustDefineSaveViewHook(c *C) { - s.state.Lock() - defer s.state.Unlock() - - hooks := observeView | queryView | loadView | changeView - s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": hooks}, nil) - view := s.dbSchema.View("setup-wifi") - tx, commitTx, err := confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, IsNil) - err = tx.Set(parsePath(c, "wifi.eph"), "foo") - c.Assert(err, IsNil) // can't write an ephemeral path w/o a save-view hook - _, _, err = commitTx() + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{ + "eph": "foo", + }) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot access %s/network/setup-wifi: write might change ephemeral data but no custodians has a save-view hook", s.devAccID)) - err = tx.Clear(s.state) - c.Assert(err, IsNil) - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - // but we can if the path can't touch any ephemeral data - _, _, err = commitTx() + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{ + "ssid": "foo", + }) c.Assert(err, IsNil) } @@ -2039,14 +2015,15 @@ func (s *confdbTestSuite) TestReadCoveringEphemeralMustDefineLoadViewHook(c *C) s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": hooks}, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1)} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) s.state.Unlock() ctx.Lock() view := s.dbSchema.View("setup-wifi") // can't read an ephemeral path w/o a load-view hook - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"eph"}, nil) + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"eph"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot schedule tasks to access %s/network/setup-wifi: read might cover ephemeral data but no custodian has a load-view hook", s.devAccID)) // so we don't block on the read @@ -2054,7 +2031,7 @@ func (s *confdbTestSuite) TestReadCoveringEphemeralMustDefineLoadViewHook(c *C) defer restore() // but if the path isn't ephemeral it's fine - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"ssid"}, nil) + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf("cannot load confdb %s/network in change 1: timed out after 0s", s.devAccID)) ctx.Unlock() @@ -2074,7 +2051,8 @@ func (s *confdbTestSuite) TestBadPathHookChecks(c *C) { s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) mockHandler := hooktest.NewMockHandler() - ctx, err := hookstate.NewContext(nil, s.state, nil, mockHandler, "") + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") c.Assert(err, IsNil) s.state.Unlock() @@ -2082,18 +2060,51 @@ func (s *confdbTestSuite) TestBadPathHookChecks(c *C) { defer ctx.Unlock() view := s.dbSchema.View("setup-wifi") - _, err = confdbstate.GetTransactionForSnapctlGet(ctx, view, []string{"foo"}, nil) + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"foo"}, nil) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "foo" through %s/network/setup-wifi: no matching rule`, s.devAccID)) _, err = confdbstate.ReadConfdb(context.Background(), s.state, view, []string{"foo"}, nil, confdb.AdminAccess) c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot get "foo" through %s/network/setup-wifi: no matching rule`, s.devAccID)) - tx, commitTxFunc, err := confdbstate.GetTransactionToSet(nil, s.state, view) - c.Assert(err, IsNil) - // this shouldn't happen unless there's a mismatch between views and schemas but check we're robust - c.Assert(tx.Set(parsePath(c, "foo"), "bar"), IsNil) - _, _, err = commitTxFunc() - c.Assert(err, ErrorMatches, `cannot check if write affects ephemeral data: cannot use "foo" as key in map`) + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"foo": "bar"}) + c.Assert(err, ErrorMatches, fmt.Sprintf(`cannot set "foo" through %s/network/setup-wifi: no matching rule`, s.devAccID)) +} + +func (s *confdbTestSuite) TestCanHookSetConfdb(c *C) { + s.state.Lock() + defer s.state.Unlock() + + mockHandler := hooktest.NewMockHandler() + chg := s.state.NewChange("test", "test change") + task := s.state.NewTask("test-task", "test task") + chg.AddTask(task) + + for _, tc := range []struct { + hook string + task *state.Task + expected bool + }{ + // we can set to modify transactions in read or write + {hook: "change-view-setup", task: task, expected: true}, + {hook: "query-view-setup", task: task, expected: true}, + // also to load data into a transaction + {hook: "load-view-setup", task: task, expected: true}, + // the other hooks cannot set + {hook: "save-view-setup", task: task, expected: false}, + {hook: "observe-view-setup", task: task, expected: false}, + // same for non-confdb hooks + {hook: "install", task: task, expected: false}, + {hook: "configure", task: task, expected: false}, + // helper expects the context to not be ephemeral + {hook: "change-view-setup", task: nil, expected: false}, + {hook: "query-view-setup", task: nil, expected: false}, + {hook: "load-view-setup", task: nil, expected: false}, + } { + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: tc.hook} + ctx, err := hookstate.NewContext(tc.task, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + c.Check(confdbstate.CanHookSetConfdb(ctx), Equals, tc.expected) + } } func (s *confdbTestSuite) TestEnsureLoopLogging(c *C) { @@ -2161,7 +2172,7 @@ func (s *confdbTestSuite) TestGetTransactionWithSecretVisibility(c *C) { c.Assert(log[0], Matches, fmt.Sprintf(`.*cannot get "private" through %s/network/setup-wifi: unauthorized access`, s.devAccID)) } -func (s *confdbTestSuite) TestReadWithOngoingWrite(c *C) { +func (s *confdbTestSuite) TestAPIReadWithOngoingWrite(c *C) { view := s.dbSchema.View("setup-wifi") firstAccess := func(ctx context.Context) string { chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) @@ -2176,7 +2187,7 @@ func (s *confdbTestSuite) TestReadWithOngoingWrite(c *C) { s.testConcurrentAccess(c, firstAccess, secondAccess) } -func (s *confdbTestSuite) TestWriteWithOngoingWrite(c *C) { +func (s *confdbTestSuite) TestAPIWriteWithOngoingWrite(c *C) { view := s.dbSchema.View("setup-wifi") firstAccess := func(ctx context.Context) string { chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) @@ -2191,7 +2202,7 @@ func (s *confdbTestSuite) TestWriteWithOngoingWrite(c *C) { s.testConcurrentAccess(c, firstAccess, secondAccess) } -func (s *confdbTestSuite) TestWriteWithOngoingRead(c *C) { +func (s *confdbTestSuite) TestAPIWriteWithOngoingRead(c *C) { view := s.dbSchema.View("setup-wifi") firstAccess := func(ctx context.Context) string { chgID, err := confdbstate.ReadConfdb(ctx, s.state, view, []string{"ssid"}, nil, 0) @@ -2223,7 +2234,7 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) doneChan := make(chan struct{}) var secondChgID string @@ -2236,7 +2247,6 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a select { case <-blockingChan: // signals that the second access is going to block - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } @@ -2250,7 +2260,6 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a select { case <-doneChan: // signals that the second access was unblocked and scheduled the operation - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } @@ -2261,7 +2270,7 @@ func (s *confdbTestSuite) testConcurrentAccess(c *C, firstAccess, secondAccess a c.Assert(secondChgID, Not(Equals), "") } -func (s *confdbTestSuite) TestMultipleConcurrentReads(c *C) { +func (s *confdbTestSuite) TestAPIMultipleConcurrentReads(c *C) { s.state.Lock() defer s.state.Unlock() @@ -2281,9 +2290,8 @@ func (s *confdbTestSuite) TestMultipleConcurrentReads(c *C) { c.Assert(err, IsNil) c.Assert(secondChgID, Not(Equals), "") - // mock a pending write - waitChan := make(chan struct{}) - s.state.Cache("confdb-accesses-"+view.Schema().Account+"/network", []confdbstate.PendingAccess{{ + waitChan := make(chan struct{}, 1) + s.state.Cache("pending-confdb-"+view.Schema().Account+"/network", []confdbstate.Access{{ ID: "foo", AccessType: confdbstate.AccessType("write"), WaitChan: waitChan, @@ -2297,7 +2305,7 @@ func (s *confdbTestSuite) TestMultipleConcurrentReads(c *C) { select { case <-waitChan: // only one read tx close this otherwise the other would panic - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected write to be unblocked but timed out") } @@ -2323,7 +2331,7 @@ func (s *confdbTestSuite) TestBlockingAccessIsCancelled(c *C) { // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) doneChan := make(chan struct{}) var readErr error @@ -2335,7 +2343,6 @@ func (s *confdbTestSuite) TestBlockingAccessIsCancelled(c *C) { select { case <-blockingChan: // signals that the timed out read is done - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } @@ -2343,14 +2350,13 @@ func (s *confdbTestSuite) TestBlockingAccessIsCancelled(c *C) { cancel() select { case <-doneChan: - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } c.Assert(readErr, ErrorMatches, ".*timed out waiting for access") } -func (s *confdbTestSuite) TestBlockingAccessTimedOut(c *C) { +func (s *confdbTestSuite) TestAPIBlockingAccessTimedOut(c *C) { s.state.Lock() defer s.state.Unlock() @@ -2365,7 +2371,7 @@ func (s *confdbTestSuite) TestBlockingAccessTimedOut(c *C) { // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) restore = confdbstate.MockDefaultWaitTimeout(time.Millisecond) defer restore() @@ -2379,14 +2385,13 @@ func (s *confdbTestSuite) TestBlockingAccessTimedOut(c *C) { select { case <-doneChan: - break case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } c.Assert(readErr, ErrorMatches, ".*timed out waiting for access") } -func (s *confdbTestSuite) TestAccessDifferentConfdbIndependently(c *C) { +func (s *confdbTestSuite) TestAPIAccessDifferentConfdbIndependently(c *C) { s.state.Lock() defer s.state.Unlock() @@ -2401,7 +2406,7 @@ func (s *confdbTestSuite) TestAccessDifferentConfdbIndependently(c *C) { // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) restore = confdbstate.MockDefaultWaitTimeout(time.Millisecond) defer restore() @@ -2413,7 +2418,6 @@ func (s *confdbTestSuite) TestAccessDifferentConfdbIndependently(c *C) { func (s *confdbTestSuite) TestFailedAccessUnblocksNextAccess(c *C) { s.state.Lock() - defer s.state.Unlock() // force the read/writes to fail due to missing custodian repo := interfaces.NewRepository() @@ -2421,6 +2425,7 @@ func (s *confdbTestSuite) TestFailedAccessUnblocksNextAccess(c *C) { view := s.dbSchema.View("setup-wifi") ctx := context.Background() + s.state.Unlock() var accErr error // mock ongoing read transaction and pending access @@ -2435,58 +2440,660 @@ func (s *confdbTestSuite) TestFailedAccessUnblocksNextAccess(c *C) { ongoingTxs[ref] = &confdbstate.ConfdbTransactions{ WriteTxID: "10", } + s.state.Lock() s.state.Set("confdb-ongoing-txs", ongoingTxs) - s.state.Cache("confdb-accesses-"+ref, nil) + s.state.Cache("pending-confdb-"+ref, nil) + s.state.Cache("scheduling-confdb-"+ref, nil) // testing helper closed when the access is about to block blockingChan := make(chan struct{}) - confdbstate.SetBlockingSignalChan(blockingChan) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) accDone := make(chan struct{}) go func() { accessFunc() + s.state.Unlock() close(accDone) }() select { case <-blockingChan: - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected access to block but timed out") } // while the access is blocked mock another one coming in s.state.Lock() - accs := s.state.Cached("confdb-accesses-" + ref) + accs := s.state.Cached("pending-confdb-" + ref) c.Assert(accs, NotNil) - pending := accs.([]confdbstate.PendingAccess) + pending := accs.([]confdbstate.Access) c.Assert(pending, HasLen, 1) - // mock another pending access - waitChan := make(chan struct{}) - pending = append(pending, confdbstate.PendingAccess{ + waitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ ID: "foo", AccessType: confdbstate.AccessType("write"), WaitChan: waitChan, }) - s.state.Cache("confdb-accesses-"+ref, pending) s.state.Unlock() - // unblock the access we started - close(pending[0].WaitChan) - // the access we mocked should be unblocked select { case <-waitChan: - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected next access to be unblocked but timed out") } // the access failed with the expected error select { case <-accDone: - case <-time.After(2 * time.Second): + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): c.Fatal("expected failed access to return but timed out") } c.Assert(accErr, ErrorMatches, ".*: no custodian snap connected") } } + +func (s *confdbTestSuite) testSnapctlConcurrentAccess(c *C, firstAccess accessFunc, secondAccess func()) { + s.state.Lock() + + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + firstAccess(ctx) + s.state.Unlock() + + // testing helper closed when the access is about to block + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + doneChan := make(chan struct{}) + go func() { + secondAccess() + close(doneChan) + }() + + select { + case <-blockingChan: + // second access blocked waiting for its turn + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // closed when the second access waits for the change to complete + blockingChan = make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-change-done", blockingChan) + + err := s.o.Settle(5 * time.Second) + c.Assert(err, IsNil) + + // once the first access completes the second access should be unblocked, scheduled + // and again while the change runs + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(5 * time.Second)): + c.Fatal("expected second access to block while change runs but timed out") + } + + // when the second access is ongoing and waiting for the change to end, the + // queues are empty + s.state.Lock() + txs, _, err := confdbstate.GetOngoingTxs(s.state, s.devAccID, "network") + s.state.Unlock() + c.Assert(err, IsNil) + c.Assert(txs.Pending, IsNil) + c.Assert(txs.Scheduling, IsNil) + + err = s.o.Settle(5 * time.Second) + c.Assert(err, IsNil) + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(5 * time.Second)): + c.Fatal("expected access to block but timed out") + } +} + +func (s *confdbTestSuite) TestSnapctlWriteOngoingRead(c *C) { + view := s.dbSchema.View("setup-wifi") + + firstAccess := func(ctx context.Context) string { + chgID, err := confdbstate.ReadConfdb(ctx, s.state, view, []string{"ssid"}, nil, 0) + c.Assert(err, IsNil) + return chgID + } + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + secondAccess := func() { + ctx.Lock() + err := confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) + ctx.Unlock() + c.Assert(err, IsNil) + } + s.testSnapctlConcurrentAccess(c, firstAccess, secondAccess) +} + +func (s *confdbTestSuite) TestSnapctlReadOngoingWrite(c *C) { + view := s.dbSchema.View("setup-wifi") + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + firstAccess := func(ctx context.Context) string { + chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) + c.Assert(err, IsNil) + return chgID + } + + secondAccess := func() { + ctx.Lock() + _, err := confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) + ctx.Unlock() + c.Assert(err, IsNil) + } + s.testSnapctlConcurrentAccess(c, firstAccess, secondAccess) +} + +func (s *confdbTestSuite) TestReadWithOngoingReadBlocksIfWriteIsPending(c *C) { + s.state.Lock() + defer s.state.Unlock() + + view := s.dbSchema.View("setup-wifi") + + // mock ongoing read transaction and pending access + ref := s.devAccID + "/network" + ongoingTxs := make(map[string]*confdbstate.ConfdbTransactions) + ongoingTxs[ref] = &confdbstate.ConfdbTransactions{ + ReadTxIDs: []string{"10"}, + } + s.state.Set("confdb-ongoing-txs", ongoingTxs) + s.state.Cache("pending-confdb-"+ref, []confdbstate.Access{{ + ID: "foo", + AccessType: confdbstate.AccessType("write"), + WaitChan: make(chan struct{}), + }}) + + // testing helper closed when the access is about to block + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + ctx, cancel := context.WithCancel(context.Background()) + readDone := make(chan struct{}) + go func() { + _, err := confdbstate.ReadConfdb(ctx, s.state, view, []string{"ssid"}, nil, 0) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot read %s: timed out waiting for access", view.ID())) + close(readDone) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the read access released the lock and blocked so we have to re-lock + s.state.Lock() + pending, ok := s.state.Cached("pending-confdb-" + ref).([]confdbstate.Access) + s.state.Unlock() + c.Assert(ok, Equals, true) + c.Assert(pending, HasLen, 2) + c.Assert(pending[1].AccessType, Equals, confdbstate.AccessType("read")) + + // cancel the pending read access which should return an error and clean up + // its waiting channel from the pending queue + cancel() + + select { + case <-readDone: + // at this point the read returned and the state was re-locked + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // check that cancelling an access cleans up the pending state + pending, ok = s.state.Cached("pending-confdb-" + ref).([]confdbstate.Access) + c.Assert(ok, Equals, true) + c.Assert(pending, HasLen, 1) + c.Assert(pending[0].AccessType, Equals, confdbstate.AccessType("write")) +} + +func (s *confdbTestSuite) TestSnapctlReadAndWriteUseHookTimeout(c *C) { + s.state.Lock() + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + view := s.dbSchema.View("setup-wifi") + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup", Timeout: time.Microsecond} + ctx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + ref := s.devAccID + "/network" + ongoingTxs := make(map[string]*confdbstate.ConfdbTransactions) + ongoingTxs[ref] = &confdbstate.ConfdbTransactions{ + WriteTxID: "10", + } + s.state.Set("confdb-ongoing-txs", ongoingTxs) + s.state.Unlock() + + ctx.Lock() + defer ctx.Unlock() + + _, err = confdbstate.ReadConfdbFromSnap(ctx, view, []string{"ssid"}, nil) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot read %s: timed out waiting for access", view.ID())) + + err = confdbstate.WriteConfdbFromSnap(ctx, view, map[string]any{"ssid": "foo"}) + c.Assert(err, ErrorMatches, fmt.Sprintf("cannot write %s: timed out waiting for access", view.ID())) +} + +func (s *confdbTestSuite) TestOngoingTxUnblocksMultiplePendingReads(c *C) { + s.state.Lock() + defer s.state.Unlock() + + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + view := s.dbSchema.View("setup-wifi") + chgID, err := confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) + c.Assert(err, IsNil) + + readOneChan, readTwoChan, writeChan := make(chan struct{}, 1), make(chan struct{}, 1), make(chan struct{}, 1) + s.state.Cache("pending-confdb-"+view.Schema().Account+"/network", []confdbstate.Access{ + { + ID: "foo", + AccessType: confdbstate.AccessType("read"), + WaitChan: readOneChan, + }, + { + ID: "bar", + AccessType: confdbstate.AccessType("read"), + WaitChan: readTwoChan, + }, + { + ID: "baz", + AccessType: confdbstate.AccessType("write"), + WaitChan: writeChan, + }, + }) + + s.state.Unlock() + err = s.o.Settle(5 * time.Second) + s.state.Lock() + c.Assert(err, IsNil) + + chg := s.state.Change(chgID) + c.Assert(chg.Status(), Equals, state.DoneStatus) + + // the running transaction unblocked the reads + select { + case <-readOneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected 1st read to be unblocked but timed out") + } + + select { + case <-readTwoChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected 2nd read to be unblocked but timed out") + } + + // but not the write + select { + case <-writeChan: + c.Fatal("expected write not to have been unblocked") + case <-time.After(testutil.HostScaledTimeout(time.Millisecond)): + } +} + +func (s *confdbTestSuite) TestAPIConfdbErrorUnblocksNextAccess(c *C) { + s.state.Lock() + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + s.state.Unlock() + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + ctx := context.Background() + + var accErr error + for _, accFunc := range []func(){ + func() { + _, accErr = confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"nonexistent": "value"}) + }, + func() { _, accErr = confdbstate.ReadConfdb(ctx, s.state, view, []string{"nonexistent"}, nil, 0) }, + } { + s.state.Lock() + // mock an ongoing write transaction so the next access blocks + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + s.state.Cache("pending-confdb-"+ref, nil) + s.state.Cache("scheduling-confdb-"+ref, nil) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + doneChan := make(chan struct{}) + go func() { + accFunc() + s.state.Unlock() + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the blocked access released the lock; set up the next pending access + s.state.Lock() + accs := s.state.Cached("pending-confdb-" + ref) + c.Assert(accs, NotNil) + pending := accs.([]confdbstate.Access) + c.Assert(pending, HasLen, 1) + + // clear the ongoing tx, queue another pending access, then unblock + nextWaitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-access", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + // the access should fail and unblock the next pending access + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected failed write to return but timed out") + } + c.Assert(accErr, ErrorMatches, `.*no matching rule`) + } +} + +// endOngoingAccess can be used to simulate the termination of a mocked ongoing +// transaction. It unsets the ongoing tx in the state, unblocks the next pending +// accesses and moves them to processing. If a new pending access is provided, +// it's set in the state. +func (s *confdbTestSuite) endOngoingAccess(c *C, newPending *confdbstate.Access) { + txs, updateFunc, err := confdbstate.GetOngoingTxs(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + defer updateFunc(txs) + + txs.ReadTxIDs = nil + txs.WriteTxID = "" + + err = confdbstate.MaybeUnblockAccesses(txs) + c.Assert(err, IsNil) + + if newPending != nil { + txs.Pending = append(txs.Pending, *newPending) + } +} + +func (s *confdbTestSuite) TestSnapctlConfdbErrorUnblocksNextAccess(c *C) { + // force the read/writes to fail due to missing custodian + s.state.Lock() + repo := interfaces.NewRepository() + ifacerepo.Replace(s.state, repo) + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "install"} + t := s.state.NewTask("run-hook", "") + chg := s.state.NewChange("some-change", "") + chg.AddTask(t) + + hookCtx, err := hookstate.NewContext(t, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + s.state.Unlock() + + var accErr error + for _, accFunc := range []func(){ + func() { + _, accErr = confdbstate.ReadConfdbFromSnap(hookCtx, view, []string{"ssid"}, nil) + }, + func() { + accErr = confdbstate.WriteConfdbFromSnap(hookCtx, view, map[string]any{"ssid": "foo"}) + }, + } { + accErr = nil + + s.state.Lock() + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + s.state.Cache("pending-confdb-"+ref, nil) + s.state.Cache("scheduling-confdb-"+ref, nil) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + s.state.Unlock() + + accDone := make(chan struct{}) + go func() { + hookCtx.Lock() + accFunc() + hookCtx.Unlock() + close(accDone) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // the blocked access released the lock; set up the next pending access + s.state.Lock() + accs := s.state.Cached("pending-confdb-" + ref) + c.Assert(accs, NotNil) + pending := accs.([]confdbstate.Access) + c.Assert(pending, HasLen, 1) + + // clear the ongoing tx, queue another pending access, then unblock + nextWaitChan := make(chan struct{}, 1) + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-access", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + // the failed access should unblock the next pending access + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-accDone: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected failed access to return but timed out") + } + c.Assert(accErr, ErrorMatches, ".*: no custodian snap connected") + } +} + +func (s *confdbTestSuite) TestReadConfdbFromSnapNoHooksToRun(c *C) { + s.state.Lock() + + // the custodian snap has no hooks, so no tasks should be scheduled + custodians := map[string]confdbHooks{"custodian-snap": noHooks} + s.setupConfdbScenario(c, custodians, nil) + + // write some value for the get to read + bag := confdb.NewJSONDatabag() + err := bag.Set(parsePath(c, "wifi.ssid"), "foo") + c.Assert(err, IsNil) + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + s.state.Set("confdb-databags", map[string]map[string]confdb.JSONDatabag{s.devAccID: {"network": bag}}) + + mockHandler := hooktest.NewMockHandler() + setup := &hookstate.HookSetup{Snap: "test-snap", Hook: "change-view-setup"} + hookCtx, err := hookstate.NewContext(nil, s.state, setup, mockHandler, "") + c.Assert(err, IsNil) + + // simulate an ongoing write transaction so the read blocks + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + s.state.Unlock() + + var tx *confdbstate.Transaction + var readErr error + doneChan := make(chan struct{}) + go func() { + hookCtx.Lock() + tx, readErr = confdbstate.ReadConfdbFromSnap(hookCtx, view, []string{"ssid"}, nil) + hookCtx.Unlock() + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // clear the ongoing tx, queue another pending access, then unblock + nextWaitChan := make(chan struct{}, 1) + s.state.Lock() + s.endOngoingAccess(c, &confdbstate.Access{ + ID: "next-write", + AccessType: confdbstate.AccessType("write"), + WaitChan: nextWaitChan, + }) + s.state.Unlock() + + // the no-hooks read path should unblock the next pending access + select { + case <-nextWaitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected next access to be unblocked but timed out") + } + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected read to complete but timed out") + } + + c.Assert(readErr, IsNil) + c.Assert(tx, NotNil) + + s.state.Lock() + defer s.state.Unlock() + + // no tasks were scheduled because there are no hooks to run + c.Assert(s.state.Changes(), HasLen, 0) + + val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) + c.Assert(err, IsNil) + c.Assert(val, Equals, "foo") +} + +func (s *confdbTestSuite) TestAPIBlockingAccessTimedOutRacesWithUnblock(c *C) { + s.state.Lock() + defer s.state.Unlock() + + s.setupConfdbScenario(c, map[string]confdbHooks{"custodian-snap": allHooks}, nil) + _, restore := s.mockConfdbHooks() + defer restore() + + view := s.dbSchema.View("setup-wifi") + ref := view.Schema().Account + "/" + view.Schema().Name + // simulate an ongoing write transaction so the read blocks + s.state.Set("confdb-ongoing-txs", map[string]*confdbstate.ConfdbTransactions{ + ref: {WriteTxID: "10"}, + }) + + blockingChan := make(chan struct{}) + confdbstate.SetBlockingSignal("wait-for-access", blockingChan) + + ctx, cancel := context.WithCancel(context.Background()) + doneChan := make(chan struct{}) + var cancelErr error + go func() { + _, cancelErr = confdbstate.WriteConfdb(ctx, s.state, view, map[string]any{"ssid": "foo"}) + close(doneChan) + }() + + select { + case <-blockingChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + + // mock a time out/cancel racing with an unblock + s.state.Lock() + cancel() + waitChan := make(chan struct{}, 1) + // in order to mock a race, we need to cancel the context and mock that another + // goroutine unblocked the channel and removed it. We won't actually unblock + // the channel otherwise we couldn't be sure which case the select would pick + txs, updateFunc, err := confdbstate.GetOngoingTxs(s.state, s.devAccID, "network") + c.Assert(err, IsNil) + c.Assert(txs.Pending, HasLen, 1) + c.Assert(txs.Pending[0].AccessType, Equals, confdbstate.AccessType("write")) + + // mock another goroutine unblocking the pending write + txs.WriteTxID = "" + txs.Scheduling = txs.Pending + txs.Pending = []confdbstate.Access{{ + ID: "next-read", + AccessType: confdbstate.AccessType("read"), + WaitChan: waitChan, + }} + updateFunc(txs) + s.state.Unlock() + + select { + case <-doneChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } + c.Assert(cancelErr, ErrorMatches, ".*timed out waiting for access") + + // even though the pending access was already unblocked, the time out/cancel + // still cleaned up its state and unblocked the next access + cached := s.state.Cached("scheduling-confdb-" + ref).([]confdbstate.Access) + c.Assert(cached, HasLen, 1) + c.Assert(cached[0].AccessType, Equals, confdbstate.AccessType("read")) + + select { + case <-waitChan: + case <-time.After(testutil.HostScaledTimeout(2 * time.Second)): + c.Fatal("expected access to block but timed out") + } +} diff --git a/overlord/confdbstate/export_test.go b/overlord/confdbstate/export_test.go index 3ed2d493d85..a040bc8ebfc 100644 --- a/overlord/confdbstate/export_test.go +++ b/overlord/confdbstate/export_test.go @@ -39,15 +39,10 @@ var ( type ( ConfdbTransactions = confdbTransactions - PendingAccess = pendingAccess + Access = access AccessType = accessType ) -const ( - CommitEdge = commitEdge - ClearTxEdge = clearTxEdge -) - func ChangeViewHandlerGenerator(ctx *hookstate.Context) hookstate.Handler { return &changeViewHandler{ctx: ctx} } @@ -96,6 +91,21 @@ func MockDefaultWaitTimeout(dur time.Duration) func() { } } -func SetBlockingSignalChan(signalChan chan struct{}) { - blockingSignalChan = signalChan +func SetBlockingSignal(key string, signalChan chan struct{}) { + if blockingSignals == nil { + blockingSignals = make(map[string]chan struct{}) + } + blockingSignals[key] = signalChan +} + +func ResetBlockingSignals() { + blockingSignals = nil +} + +func MaybeUnblockAccesses(txs *confdbTransactions) error { + return maybeUnblockAccesses(txs) +} + +func GetOngoingTxs(st *state.State, account, schemaName string) (ongoingTxs *confdbTransactions, updateTxStateFunc func(*confdbTransactions), err error) { + return getOngoingTxs(st, account, schemaName) } diff --git a/overlord/hookstate/ctlcmd/ctlcmd.go b/overlord/hookstate/ctlcmd/ctlcmd.go index ce7354f1205..2b7774490d2 100644 --- a/overlord/hookstate/ctlcmd/ctlcmd.go +++ b/overlord/hookstate/ctlcmd/ctlcmd.go @@ -148,7 +148,7 @@ func (f ForbiddenCommandError) Error() string { // nonRootAllowed lists the commands that can be performed even when snapctl // is invoked not by root. -var nonRootAllowed = []string{"get", "services", "set-health", "is-connected", "system-mode", "refresh", "model", "version"} +var nonRootAllowed = []string{"get", "services", "set-health", "is-connected", "system-mode", "refresh", "model", "version", "is-ready"} // Run runs the requested command. func Run(context *hookstate.Context, args []string, uid uint32, features []string) (stdout, stderr []byte, err error) { diff --git a/overlord/hookstate/ctlcmd/export_test.go b/overlord/hookstate/ctlcmd/export_test.go index b91f0fea492..bba0e79095f 100644 --- a/overlord/hookstate/ctlcmd/export_test.go +++ b/overlord/hookstate/ctlcmd/export_test.go @@ -22,6 +22,7 @@ package ctlcmd import ( "context" "errors" + "time" "github.com/snapcore/snapd/asserts" "github.com/snapcore/snapd/asserts/snapasserts" @@ -51,6 +52,8 @@ var ( ) type KmodCommand = kmodCommand +type IsReadyCommand = isReadyCommand +type ChangeRateLimitKey = changeRateLimitKey func MockKmodCheckConnection(f func(*hookstate.Context, string, []string) error) (restore func()) { r := testutil.Backup(&kmodCheckConnection) @@ -187,11 +190,11 @@ func MockNewStatusDecorator(f func(ctx context.Context, isGlobal bool, uid strin return restore } -func MockConfdbstateTransactionForSet(f func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error)) (restore func()) { - old := confdbstateTransactionForSet - confdbstateTransactionForSet = f +func MockConfdbstateWriteConfdb(f func(*hookstate.Context, *confdb.View, map[string]any) error) (restore func()) { + old := confdbstateWriteConfdb + confdbstateWriteConfdb = f return func() { - confdbstateTransactionForSet = old + confdbstateWriteConfdb = old } } @@ -203,10 +206,17 @@ func MockConfdbstateGetView(f func(st *state.State, account, confdbName, viewNam } } -func MockConfdbstateTransactionForGet(f func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error)) (restore func()) { - old := confdbstateTransactionForGet - confdbstateTransactionForGet = f +func MockConfdbstateReadConfdb(f func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error)) (restore func()) { + old := confdbstateReadConfdb + confdbstateReadConfdb = f return func() { - confdbstateTransactionForGet = old + confdbstateReadConfdb = old } } + +// TODO:GOVERSION: use time bubbles once project is updated to Go 1.26 +func MockTimeAfter(f func(time.Duration) <-chan time.Time) (restore func()) { + old := timeAfter + timeAfter = f + return func() { timeAfter = old } +} diff --git a/overlord/hookstate/ctlcmd/get.go b/overlord/hookstate/ctlcmd/get.go index 0f84ee39f4a..aaf486871d8 100644 --- a/overlord/hookstate/ctlcmd/get.go +++ b/overlord/hookstate/ctlcmd/get.go @@ -42,8 +42,8 @@ import ( ) var ( - confdbstateGetView = confdbstate.GetView - confdbstateTransactionForGet = confdbstate.GetTransactionForSnapctlGet + confdbstateGetView = confdbstate.GetView + confdbstateReadConfdb = confdbstate.ReadConfdbFromSnap ) type getCommand struct { @@ -448,7 +448,8 @@ func (c *getCommand) getConfdbValues(ctx *hookstate.Context, plugName string, re return err } - tx, err := confdbstateTransactionForGet(ctx, view, requests, constraints) + // TODO: add --wait-for timeout to options and cache in hookstate context + tx, err := confdbstateReadConfdb(ctx, view, requests, constraints) if err != nil { return err } diff --git a/overlord/hookstate/ctlcmd/get_test.go b/overlord/hookstate/ctlcmd/get_test.go index faf0ffa6b9c..05f9e328896 100644 --- a/overlord/hookstate/ctlcmd/get_test.go +++ b/overlord/hookstate/ctlcmd/get_test.go @@ -634,7 +634,7 @@ func (s *confdbSuite) TestConfdbGetSingleView(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, DeepEquals, []string{"ssid"}) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") @@ -658,7 +658,7 @@ func (s *confdbSuite) TestConfdbGetManyViews(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, DeepEquals, []string{"ssid", "password"}) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") @@ -687,7 +687,7 @@ func (s *confdbSuite) TestConfdbGetNoRequest(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, IsNil) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") @@ -850,7 +850,7 @@ func (s *confdbSuite) TestConfdbGetPrevious(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "bar") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1017,7 +1017,7 @@ func (s *confdbSuite) TestConfdbAccessUnconnectedPlug(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { c.Fatal("should not allow access to confdb") return tx, nil }) @@ -1077,7 +1077,7 @@ func (s *confdbSuite) TestConfdbDefaultIfNoData(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1098,7 +1098,7 @@ func (s *confdbSuite) TestConfdbDefaultNoFallbackIfTyped(c *C) { err = tx.Set(parsePath(c, "wifi.ssid"), "foo") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1117,7 +1117,7 @@ func (s *confdbSuite) TestConfdbDefaultWithOtherFlags(c *C) { tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") c.Assert(err, IsNil) - restore := ctlcmd.MockConfdbstateTransactionForGet(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(*hookstate.Context, *confdb.View, []string, map[string]any) (*confdbstate.Transaction, error) { return tx, nil }) defer restore() @@ -1191,7 +1191,7 @@ func (s *confdbSuite) TestConfdbGetWithConstraints(c *C) { s.state.Unlock() var gotConstraints map[string]any - restore := ctlcmd.MockConfdbstateTransactionForGet(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { gotConstraints = constraints return tx, nil }) @@ -1315,7 +1315,7 @@ func (s *confdbSuite) TestConfdbGetTypedConstraints(c *C) { s.state.Unlock() var gotConstraints map[string]any - restore := ctlcmd.MockConfdbstateTransactionForGet(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(_ *hookstate.Context, _ *confdb.View, _ []string, constraints map[string]any) (*confdbstate.Transaction, error) { gotConstraints = constraints return tx, nil }) @@ -1354,7 +1354,7 @@ func (s *confdbSuite) TestConfdbGetSecretVisibility(c *C) { c.Assert(err, IsNil) s.state.Unlock() - restore := ctlcmd.MockConfdbstateTransactionForGet(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { + restore := ctlcmd.MockConfdbstateReadConfdb(func(ctx *hookstate.Context, view *confdb.View, requests []string, _ map[string]any) (*confdbstate.Transaction, error) { c.Assert(requests, DeepEquals, []string{"password"}) c.Assert(view.Schema().Account, Equals, s.devAccID) c.Assert(view.Schema().Name, Equals, "network") diff --git a/overlord/hookstate/ctlcmd/helpers.go b/overlord/hookstate/ctlcmd/helpers.go index fe2c39572a9..f5178178e8d 100644 --- a/overlord/hookstate/ctlcmd/helpers.go +++ b/overlord/hookstate/ctlcmd/helpers.go @@ -48,6 +48,8 @@ var ( snapstateRemoveComponents = snapstate.RemoveComponents ) +var timeAfter = time.After + var ( serviceControlChangeKind = swfeats.RegisterChangeKind("service-control") snapctlInstallChangeKind = swfeats.RegisterChangeKind("snapctl-install") @@ -61,6 +63,8 @@ func init() { } } +const snapctlDebounceWindow = 200 * time.Millisecond + // finalSeedTask is the last task that should run during seeding. This is used // in the special handling of the "seed" change, which requires that we // introspect the change for this specific task. Finding this task allows us to @@ -451,6 +455,7 @@ func runSnapManagementCommand(hctx *hookstate.Context, cmd managementCommand) er chg := st.NewChange(changeKind, fmt.Sprintf("%s components %v for snap %s", cmdVerb, cmd.components, hctx.InstanceName())) + chg.Set("initiated-by-snap", hctx.InstanceName()) for _, ts := range tss { chg.AddAll(ts) } @@ -491,6 +496,140 @@ func jsonRaw(v any) *json.RawMessage { return &raw } +type changeRateLimitKey struct { + ChangeID string +} + +// isReady checks if the change is ready, if it is, it returns the status, otherwise state.DoingStatus. +func isReady(hctx *hookstate.Context, changeID string) (state.Status, error) { + callerSnapName := hctx.InstanceName() + + st := hctx.State() + st.Lock() + defer st.Unlock() + + chg := st.Change(changeID) + + if chg == nil { + return state.DefaultStatus, fmt.Errorf("change %q not found", changeID) + } + + var initiatorSnapName string + err := chg.Get("initiated-by-snap", &initiatorSnapName) + if err != nil { + return state.DefaultStatus, fmt.Errorf("change %q not found", changeID) + } + + if initiatorSnapName != callerSnapName { + return state.DefaultStatus, fmt.Errorf("change %q not found", changeID) + } + + wait, err := rateLimit(st, changeID, snapctlDebounceWindow) + if err != nil { + return state.DefaultStatus, err + } + + return unlockAndWaitForStatus(st, chg, wait), nil +} + +// unlockAndWaitForStatus unlocks the state and waits for the change to be ready. +// The lock must be held prior to calling, and will be re-acquired before returning. +// Returns doingStatus if the change is still in progress, otherwise returns the final +// status of the change. +func unlockAndWaitForStatus(st *state.State, chg *state.Change, wait time.Duration) state.Status { + st.Unlock() + // note: we cannot defer the re-lock, since we must re-lock prior to + // calculating the return value in some branches. + + ready := chg.Ready() + + // The check ensures that both select cases aren't true immediately. + if wait <= 0 { + select { + // use default so the channel is prioritized. + case <-ready: + st.Lock() + return chg.Status() + default: + st.Lock() + return state.DoingStatus + } + } + + // Because the wait could've been > 0, the last select between a closed ready channel + // and a timer.After channel would've be racy. + select { + case <-ready: + case <-timeAfter(wait): + st.Lock() + return state.DoingStatus + } + + st.Lock() + return chg.Status() +} + +// rateLimit returns the amount of time that should be waited before accessing +// this change via snapctl. Internally, data associated with the change is +// cached so that all access to the change shares the same rate limit. +// The lock must be acquired before calling, as it modifies the state object. +func rateLimit(st *state.State, changeID string, rate time.Duration) (wait time.Duration, err error) { + now := time.Now() + + accessed, err := changeAccessedAt(st, changeID) + if err != nil { + return 0, err + } + + // first time through, we just set the change access to now. next request + // must wait at least "rate" duration before access. + if accessed.IsZero() { + setChangeAccessedAt(st, now, changeID) + return 0, nil + } + + durationSinceLastAccess := now.Sub(accessed) + + // user waited on their own, no waiting needed. next access will require + // waiting at least "rate" duration. + if durationSinceLastAccess >= rate { + setChangeAccessedAt(st, now, changeID) + return 0, nil + } + + // user needs to wait a bit still. note that durationSinceLastAccess might + // be negative, since "accessed" could be in the future. this can happen + // when there are multiple requests in parallel, within a duration less than + // "rate". + wait = rate - durationSinceLastAccess + + // current request must wait. next request must wait this amount of time, + // plus at least "rate" duration. + setChangeAccessedAt(st, now.Add(wait), changeID) + + return wait, nil +} + +func changeAccessedAt(st *state.State, changeID string) (time.Time, error) { + key := changeRateLimitKey{ChangeID: changeID} + accessedAt := st.Cached(key) + if accessedAt == nil { + return time.Time{}, nil + } + + accessedNano, ok := accessedAt.(int64) + if !ok { + return time.Time{}, fmt.Errorf("error: invalid type (%T) for access time", accessedAt) + } + + return time.Unix(0, accessedNano), nil +} + +func setChangeAccessedAt(st *state.State, accessed time.Time, changeID string) { + key := changeRateLimitKey{ChangeID: changeID} + st.Cache(key, accessed.UnixNano()) +} + // getAttribute unmarshals into result the value of the provided key from attributes map. // If the key does not exist, an error of type *NoAttributeError is returned. // The provided key may be formed as a dotted key path through nested maps. diff --git a/overlord/hookstate/ctlcmd/is_ready.go b/overlord/hookstate/ctlcmd/is_ready.go new file mode 100644 index 00000000000..9289a1cc5d9 --- /dev/null +++ b/overlord/hookstate/ctlcmd/is_ready.go @@ -0,0 +1,89 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package ctlcmd + +import ( + "fmt" + + "github.com/snapcore/snapd/i18n" + "github.com/snapcore/snapd/overlord/state" +) + +type isReadyCommand struct { + baseCommand +} + +const ( + changeReadyExitCode = iota + changeNotReadyExitCode + changeUnsuccessfulExitCode + otherErrorExitCode +) + +var shortIsReadyHelp = i18n.G(`Return the status of the associated change id.`) +var longIsReadyHelp = i18n.G(` +The is-ready command is used to query the status of change ids that are returned +by asynchronous snapctl commands. + +$ snapctl is-ready + 0: change completed successfully (Done) + 1: change is not ready + 2: change is ready but did not complete successfully (Undone, Error, Hold) + 3: other errors (invalid change id, permissions error) +stdout: empty, exit code conveys change readiness +stderr: empty for exit codes 0 and 1. Contains relevant errors for exit codes 2 and 3. +`) + +func init() { + addCommand("is-ready", shortIsReadyHelp, longIsReadyHelp, func() command { + return &isReadyCommand{} + }) +} + +func (c *isReadyCommand) Execute(args []string) error { + ctx, err := c.ensureContext() + if err != nil { + return err + } + + if len(args) != 1 { + return fmt.Errorf("invalid number of arguments: expected 1, got %d", len(args)) + } + + changeID := args[0] + + ready, err := isReady(ctx, changeID) + + if err != nil { + fmt.Fprint(c.stderr, err.Error()) + return &UnsuccessfulError{ExitCode: otherErrorExitCode} + } + + if !ready.Ready() { + return &UnsuccessfulError{ExitCode: changeNotReadyExitCode} + } + + if ready != state.DoneStatus { + fmt.Fprintf(c.stderr, "change finished with status %s", ready) + return &UnsuccessfulError{ExitCode: changeUnsuccessfulExitCode} + } + + return nil +} diff --git a/overlord/hookstate/ctlcmd/is_ready_test.go b/overlord/hookstate/ctlcmd/is_ready_test.go new file mode 100644 index 00000000000..489548bce73 --- /dev/null +++ b/overlord/hookstate/ctlcmd/is_ready_test.go @@ -0,0 +1,249 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package ctlcmd_test + +import ( + "time" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/dirs" + "github.com/snapcore/snapd/overlord/hookstate" + "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" + "github.com/snapcore/snapd/overlord/hookstate/hooktest" + "github.com/snapcore/snapd/overlord/state" + "github.com/snapcore/snapd/snap" + "github.com/snapcore/snapd/testutil" +) + +type isReadySuite struct { + testutil.BaseTest + mockHandler *hooktest.MockHandler +} + +var _ = Suite(&isReadySuite{}) + +func (s *isReadySuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + dirs.SetRootDir(c.MkDir()) + s.AddCleanup(func() { dirs.SetRootDir("/") }) + s.mockHandler = hooktest.NewMockHandler() +} + +// setupChangeAndContext creates a state, a change (with an optional initiator), +// and a non-ephemeral hook context for "test-snap". +func (s *isReadySuite) setupChangeAndContext(c *C, taskStatus state.Status, initiatorSnap string) (*state.State, *hookstate.Context, string) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("snapctl-install", "install via snapctl") + task := st.NewTask("test-task", "test task") + chg.AddTask(task) + + if initiatorSnap != "" { + chg.Set("initiated-by-snap", initiatorSnap) + } + + task.SetStatus(taskStatus) + + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "install"} + ctx, err := hookstate.NewContext(task, st, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + return st, ctx, chg.ID() +} + +func (s *isReadySuite) TestIsReadyNoContext(c *C) { + _, _, err := ctlcmd.Run(nil, []string{"is-ready", "1"}, 0, nil) + c.Assert(err, ErrorMatches, `cannot invoke snapctl operation commands.*from outside of a snap`) +} + +func (s *isReadySuite) TestIsReadyArgCount(c *C) { + _, ctx, _ := s.setupChangeAndContext(c, state.DoneStatus, "test-snap") + _, _, err := ctlcmd.Run(ctx, []string{"is-ready"}, 0, nil) + c.Assert(err, ErrorMatches, `invalid number of arguments: expected 1, got 0`) + + _, _, err = ctlcmd.Run(ctx, []string{"is-ready", "1", "extra-arg"}, 0, nil) + c.Assert(err, ErrorMatches, `invalid number of arguments: expected 1, got 2`) +} + +func (s *isReadySuite) TestIsReadyChangeNotFound(c *C) { + _, ctx, _ := s.setupChangeAndContext(c, state.DoneStatus, "") + _, stderr, err := ctlcmd.Run(ctx, []string{"is-ready", "nonexistent-id"}, 0, nil) + c.Assert(err, DeepEquals, &ctlcmd.UnsuccessfulError{ExitCode: 3}) + c.Check(string(stderr), Matches, `change "nonexistent-id" not found`) +} + +func (s *isReadySuite) TestIsReadyLogic(c *C) { + var logicTests = []struct { + taskStatus state.Status + initiatorSnap string // empty = don't set initiated-by-snap on the change + errValue error // if set, expect err to deep equal this value + expectedOut string + expectedStderr string // if set, checked as regexp match against stderr + }{ + { + taskStatus: state.DoneStatus, + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 3}, + expectedStderr: `change .* not found`, + }, + { + taskStatus: state.DoneStatus, + initiatorSnap: "other-snap", // different from context snap "test-snap" + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 3}, + expectedStderr: `change .* not found`, + }, + { + taskStatus: state.DoneStatus, + initiatorSnap: "test-snap", + }, + { + taskStatus: state.DoingStatus, + initiatorSnap: "test-snap", + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 1}, + }, + { + taskStatus: state.ErrorStatus, + initiatorSnap: "test-snap", + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 2}, + expectedStderr: `change finished with status Error`, + }, + { + taskStatus: state.HoldStatus, + initiatorSnap: "test-snap", + errValue: &ctlcmd.UnsuccessfulError{ExitCode: 2}, + expectedStderr: `change finished with status Hold`, + }, + } + + for _, tt := range logicTests { + _, ctx, changeID := s.setupChangeAndContext(c, tt.taskStatus, tt.initiatorSnap) + stdout, stderr, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + if tt.errValue != nil { + c.Assert(err, DeepEquals, tt.errValue) + } else { + c.Assert(err, IsNil) + } + c.Check(string(stdout), Equals, tt.expectedOut) + if tt.expectedStderr != "" { + c.Check(string(stderr), Matches, tt.expectedStderr) + } else { + c.Check(string(stderr), Equals, "") + } + } +} + +// Rate-limiting tests +func (s *isReadySuite) rateLimitSetup(c *C, taskStatus state.Status, lastAccessedTime any) (*hookstate.Context, string) { + st := state.New(nil) + st.Lock() + defer st.Unlock() + + chg := st.NewChange("snapctl-install", "install via snapctl") + task := st.NewTask("test-task", "test task") + chg.AddTask(task) + chg.Set("initiated-by-snap", "test-snap") + + if lastAccessedTime != nil { + st.Cache(ctlcmd.ChangeRateLimitKey{ChangeID: chg.ID()}, lastAccessedTime) + } + + task.SetStatus(taskStatus) + + setup := &hookstate.HookSetup{Snap: "test-snap", Revision: snap.R(1), Hook: "install"} + ctx, err := hookstate.NewContext(task, st, setup, s.mockHandler, "") + c.Assert(err, IsNil) + + return ctx, chg.ID() +} + +// TestIsReadyMissingLastAccessed verifies that is-ready treats a missing +// last-accessed cache entry (e.g. after a snapd restart) as a first access and +// proceeds to report the real change status rather than returning an error. +func (s *isReadySuite) TestIsReadyMissingLastAccessed(c *C) { + ctx, changeID := s.rateLimitSetup(c, state.DoneStatus, nil) + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, IsNil) +} + +// TestIsReadyRateLimitDelaysPolling verifies that when a snap polls within the +// 200 ms debounce window, is-ready sleeps for the remaining window duration +// before checking the change status. +func (s *isReadySuite) TestIsReadyRateLimitDelaysPolling(c *C) { + // A last-accessed time in the future guarantees we are within the debounce + // window, ensuring timeAfter is called with a positive duration. + ctx, changeID := s.rateLimitSetup(c, state.DoneStatus, time.Now().Add(time.Second).UnixNano()) + + var waitedFor time.Duration + restore := ctlcmd.MockTimeAfter(func(d time.Duration) <-chan time.Time { + waitedFor = d + return make(chan time.Time) // never fires; chg.Ready() wins + }) + defer restore() + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, IsNil) + c.Check(waitedFor > 0, Equals, true) +} + +// TestIsReadyRateLimitTimerFires verifies that when timeAfter fires before the +// change is ready, is-ready reports DoingStatus (exit code 1) and the timer +// channel is drained. +func (s *isReadySuite) TestIsReadyRateLimitTimerFires(c *C) { + // A last-accessed time in the future puts us inside the debounce window. + // The task is left in DoingStatus so chg.Ready() never fires, ensuring + // the timer case is the only one that can win the select. + ctx, changeID := s.rateLimitSetup(c, state.DoingStatus, time.Now().Add(time.Second).UnixNano()) + + timerCh := make(chan time.Time, 1) + timerCh <- time.Now() // pre-fill so the timer fires immediately + restore := ctlcmd.MockTimeAfter(func(d time.Duration) <-chan time.Time { + return timerCh + }) + defer restore() + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, DeepEquals, &ctlcmd.UnsuccessfulError{ExitCode: 1}) + c.Check(len(timerCh), Equals, 0) // element was consumed by the select +} + +// TestIsReadyExpiredWindowSkipsTimeAfter verifies that when the debounce window +// has already elapsed, is-ready returns the change status directly +func (s *isReadySuite) TestIsReadyExpiredWindowSkipsTimeAfter(c *C) { + // A last-accessed time sufficiently in the past guarantees toWait <= 0. + ctx, changeID := s.rateLimitSetup(c, state.DoneStatus, time.Now().Add(-time.Second).UnixNano()) + + called := false + restore := ctlcmd.MockTimeAfter(func(d time.Duration) <-chan time.Time { + called = true + return make(chan time.Time) + }) + defer restore() + + _, _, err := ctlcmd.Run(ctx, []string{"is-ready", changeID}, 0, nil) + + c.Assert(err, IsNil) + c.Check(called, Equals, false) +} diff --git a/overlord/hookstate/ctlcmd/set.go b/overlord/hookstate/ctlcmd/set.go index e359e48f8b6..d29ab8dcd1f 100644 --- a/overlord/hookstate/ctlcmd/set.go +++ b/overlord/hookstate/ctlcmd/set.go @@ -36,7 +36,7 @@ import ( "github.com/snapcore/snapd/snap" ) -var confdbstateTransactionForSet = confdbstate.GetTransactionToSet +var confdbstateWriteConfdb = confdbstate.WriteConfdbFromSnap type setCommand struct { baseCommand @@ -244,7 +244,7 @@ func (s *setCommand) setInterfaceSetting(context *hookstate.Context, plugOrSlot return nil } -func setConfdbValues(ctx *hookstate.Context, plugName string, requests map[string]any) error { +func setConfdbValues(ctx *hookstate.Context, plugName string, values map[string]any) error { ctx.Lock() defer ctx.Unlock() @@ -267,28 +267,6 @@ func setConfdbValues(ctx *hookstate.Context, plugName string, requests map[strin return fmt.Errorf("cannot modify confdb in %q hook", ctx.HookName()) } - tx, commitTxFunc, err := confdbstateTransactionForSet(ctx, ctx.State(), view) - if err != nil { - return err - } - - err = confdbstate.SetViaView(tx, view, requests) - if err != nil { - return err - } - - // if a new transaction was created, commit it - if commitTxFunc != nil { - _, waitChan, err := commitTxFunc() - if err != nil { - return err - } - - // wait for the transaction to be committed - ctx.Unlock() - <-waitChan - ctx.Lock() - } - - return nil + // TODO: add --wait-for timeout to options and cache in hookstate context + return confdbstateWriteConfdb(ctx, view, values) } diff --git a/overlord/hookstate/ctlcmd/set_test.go b/overlord/hookstate/ctlcmd/set_test.go index 2ac846b60ab..d73689771d9 100644 --- a/overlord/hookstate/ctlcmd/set_test.go +++ b/overlord/hookstate/ctlcmd/set_test.go @@ -28,14 +28,12 @@ import ( "github.com/snapcore/snapd/confdb" "github.com/snapcore/snapd/interfaces" - "github.com/snapcore/snapd/overlord/confdbstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/snap" - "github.com/snapcore/snapd/testutil" ) type setSuite struct { @@ -412,44 +410,14 @@ func parsePath(c *C, path string) []confdb.Accessor { return accs } -func (s *confdbSuite) TestConfdbSetSingleView(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil - }) - defer restore() - - stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"set", "--view", ":write-wifi", "ssid=other-ssid"}, 0, nil) - c.Assert(err, IsNil) - c.Check(stdout, IsNil) - c.Check(stderr, IsNil) - s.mockContext.Lock() - c.Assert(s.mockContext.Done(), IsNil) - s.mockContext.Unlock() - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, "other-ssid") -} - func (s *confdbSuite) TestConfdbSetSingleViewNewTransaction(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - var called bool - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, func() (string, <-chan struct{}, error) { - called = true - waitChan := make(chan struct{}) - close(waitChan) - return "123", waitChan, nil - }, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + called = true + c.Assert(values, DeepEquals, map[string]any{ + "ssid": "other-ssid", + }) + return nil }) defer restore() @@ -457,22 +425,16 @@ func (s *confdbSuite) TestConfdbSetSingleViewNewTransaction(c *C) { c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - c.Assert(called, Equals, true) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, DeepEquals, "other-ssid") } func (s *confdbSuite) TestConfdbSetManyViews(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{ + "ssid": "other-ssid", + "password": "other-secret", + }) + return nil }) defer restore() @@ -480,14 +442,6 @@ func (s *confdbSuite) TestConfdbSetManyViews(c *C) { c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "other-ssid") - - val, err = tx.Get(parsePath(c, "wifi.psk"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "other-secret") } func (s *confdbSuite) TestConfdbSetInvalid(c *C) { @@ -516,19 +470,9 @@ func (s *confdbSuite) TestConfdbSetInvalid(c *C) { } func (s *confdbSuite) TestConfdbSetExclamationMark(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.psk"), "bar") - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{"password": nil}) + return nil }) defer restore() @@ -536,24 +480,15 @@ func (s *confdbSuite) TestConfdbSetExclamationMark(c *C) { c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - - _, err = tx.Get(parsePath(c, "wifi.psk"), nil) - c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) - - val, err := tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, IsNil) - c.Assert(val, Equals, "foo") } func (s *confdbSuite) TestConfdbModifyHooks(c *C) { s.state.Lock() defer s.state.Unlock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - c.Assert(err, IsNil) - - restore := ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + restore := ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{"password": "thing"}) + return nil }) defer restore() diff --git a/overlord/hookstate/ctlcmd/unset_test.go b/overlord/hookstate/ctlcmd/unset_test.go index 884a0007128..53f18bfaf50 100644 --- a/overlord/hookstate/ctlcmd/unset_test.go +++ b/overlord/hookstate/ctlcmd/unset_test.go @@ -25,14 +25,12 @@ import ( . "gopkg.in/check.v1" "github.com/snapcore/snapd/confdb" - "github.com/snapcore/snapd/overlord/confdbstate" "github.com/snapcore/snapd/overlord/configstate/config" "github.com/snapcore/snapd/overlord/hookstate" "github.com/snapcore/snapd/overlord/hookstate/ctlcmd" "github.com/snapcore/snapd/overlord/hookstate/hooktest" "github.com/snapcore/snapd/overlord/state" "github.com/snapcore/snapd/snap" - "github.com/snapcore/snapd/testutil" ) type unsetSuite struct { @@ -165,31 +163,18 @@ func (s *unsetSuite) TestCommandWithoutContext(c *C) { } func (s *confdbSuite) TestConfdbUnsetManyViews(c *C) { - s.state.Lock() - tx, err := confdbstate.NewTransaction(s.state, s.devAccID, "network") - s.state.Unlock() - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.ssid"), "foo") - c.Assert(err, IsNil) - - err = tx.Set(parsePath(c, "wifi.psk"), "bar") - c.Assert(err, IsNil) - - ctlcmd.MockConfdbstateTransactionForSet(func(*hookstate.Context, *state.State, *confdb.View) (*confdbstate.Transaction, confdbstate.CommitTxFunc, error) { - return tx, nil, nil + ctlcmd.MockConfdbstateWriteConfdb(func(_ *hookstate.Context, _ *confdb.View, values map[string]any) error { + c.Assert(values, DeepEquals, map[string]any{ + "ssid": nil, + "password": nil, + }) + return nil }) stdout, stderr, err := ctlcmd.Run(s.mockContext, []string{"unset", "--view", ":write-wifi", "ssid", "password"}, 0, nil) c.Assert(err, IsNil) c.Check(stdout, IsNil) c.Check(stderr, IsNil) - - _, err = tx.Get(parsePath(c, "wifi.ssid"), nil) - c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) - - _, err = tx.Get(parsePath(c, "wifi.psk"), nil) - c.Assert(err, testutil.ErrorIs, &confdb.NoDataError{}) } func (s *confdbSuite) TestConfdbUnsetInvalid(c *C) { diff --git a/overlord/snapstate/check_snap.go b/overlord/snapstate/check_snap.go index 1833d2b974e..48e2db07bd4 100644 --- a/overlord/snapstate/check_snap.go +++ b/overlord/snapstate/check_snap.go @@ -144,7 +144,12 @@ func validateInfoAndFlags(info *snap.Info, snapst *SnapState, flags Flags) error // check assumes err := naming.ValidateAssumes(info.Assumes, snapdtool.Version, featureSet, arch.DpkgArchitecture()) if err != nil { - return fmt.Errorf("snap %q assumes %w (try to refresh snapd)", info.InstanceName(), err) + askToRefreshSnapd := " (try to refresh snapd)" + isaErr := &naming.IsaError{} + if errors.As(err, &isaErr) { + askToRefreshSnapd = "" + } + return fmt.Errorf("snap %q assumes %w%s", info.InstanceName(), err, askToRefreshSnapd) } // check and create system-usernames diff --git a/overlord/snapstate/check_snap_test.go b/overlord/snapstate/check_snap_test.go index b0406619509..c445322f8d6 100644 --- a/overlord/snapstate/check_snap_test.go +++ b/overlord/snapstate/check_snap_test.go @@ -27,6 +27,7 @@ import ( . "gopkg.in/check.v1" "github.com/snapcore/snapd/arch" + "github.com/snapcore/snapd/arch/archtest" "github.com/snapcore/snapd/asserts" "github.com/snapcore/snapd/dirs" "github.com/snapcore/snapd/osutil" @@ -95,6 +96,8 @@ architectures: } func (s *checkSnapSuite) TestCheckSnapAssumes(c *C) { + s.AddCleanup(archtest.MockArchitecture("arm64")) + var assumesTests = []struct { version string assumes string @@ -109,6 +112,9 @@ func (s *checkSnapSuite) TestCheckSnapAssumes(c *C) { assumes: "[f1, f2]", classic: true, error: `snap "foo" assumes unsupported features: f1, f2 \(try to refresh snapd\)`, + }, { + assumes: "[isa-arm64-someisa]", + error: `snap "foo" assumes isa-arm64-someisa: ISA specification is not supported for arch: arm64`, }, } diff --git a/seclog/audit_linux.go b/seclog/audit_linux.go new file mode 100644 index 00000000000..ac66ac66ca8 --- /dev/null +++ b/seclog/audit_linux.go @@ -0,0 +1,197 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +// go1.21 is required for binary.NativeEndian which is used to serialize +// netlink headers in host byte order. NativeEndian is supported on all +// architectures snapd targets: amd64, arm, arm64, ppc64le, riscv64. +// See https://cs.opensource.google/go/go/+/refs/tags/go1.26.2:src/encoding/binary/native_endian_little.go +// The nonativeendian tag allows excluding this file on toolchains that +// lack NativeEndian support. +// See https://go.dev/doc/go1.21#encoding/binary +//go:build go1.21 && !nonativeendian + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "sync/atomic" + "syscall" +) + +const ( + // AUDIT_TRUSTED_APP is the audit message type for trusted application messages. + // See https://github.com/linux-audit/audit-userspace/blob/master/lib/audit-records.h + auditTrustedApp = 1121 + + // NETLINK_AUDIT is the netlink protocol for audit. + // See https://github.com/torvalds/linux/blob/master/include/uapi/linux/netlink.h + netlinkAudit = 9 +) + +// netlinkOps abstracts the syscall operations needed to open, bind, query, +// send to, and close a netlink socket. Production code uses [realNetlinkOps]; +// tests can substitute a recording or stubbing implementation. +type netlinkOps interface { + Socket(domain, typ, proto int) (int, error) + Bind(fd int, sa syscall.Sockaddr) error + Getsockname(fd int) (syscall.Sockaddr, error) + Sendto(fd int, p []byte, flags int, to syscall.Sockaddr) error + Close(fd int) error +} + +// realNetlinkOps delegates every operation to the corresponding syscall. +type realNetlinkOps struct{} + +func (realNetlinkOps) Socket(domain, typ, proto int) (int, error) { + return syscall.Socket(domain, typ, proto) +} + +func (realNetlinkOps) Bind(fd int, sa syscall.Sockaddr) error { + return syscall.Bind(fd, sa) +} + +func (realNetlinkOps) Getsockname(fd int) (syscall.Sockaddr, error) { + return syscall.Getsockname(fd) +} + +func (realNetlinkOps) Sendto(fd int, p []byte, flags int, to syscall.Sockaddr) error { + return syscall.Sendto(fd, p, flags, to) +} + +func (realNetlinkOps) Close(fd int) error { + return syscall.Close(fd) +} + +var netlink netlinkOps = realNetlinkOps{} + +func init() { + registerSink(SinkAudit, auditSinkFactory{}) +} + +// auditSinkFactory implements [sinkFactory] for the kernel audit sink. +type auditSinkFactory struct{} + +// Ensure [auditSinkFactory] implements [sinkFactory]. +var _ sinkFactory = auditSinkFactory{} + +// Open opens a netlink audit socket and returns an [auditWriter] +// that sends each written payload as an AUDIT_TRUSTED_APP. The appID is +// currently unused but accepted for sink factory compatibility. +func (auditSinkFactory) Open(_ string) (io.Writer, error) { + // SOCK_CLOEXEC prevents the fd from leaking to child processes. + fd, err := netlink.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, netlinkAudit) + if err != nil { + return nil, fmt.Errorf("cannot open audit socket: %w", err) + } + addr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 0, // let kernel assign port ID + Groups: 0, + } + if err := netlink.Bind(fd, addr); err != nil { + netlink.Close(fd) + return nil, fmt.Errorf("cannot bind audit socket: %w", err) + } + portID, err := getPortID(fd) + if err != nil { + netlink.Close(fd) + return nil, fmt.Errorf("cannot get audit socket port ID: %w", err) + } + return &auditWriter{fd: fd, portID: portID}, nil +} + +// getPortID returns the kernel-assigned port ID of the netlink socket. +// When binding with Pid 0, the kernel assigns a unique port ID that may +// or may not equal the process PID. This value must be used in outgoing +// netlink message headers. +func getPortID(fd int) (uint32, error) { + sa, err := netlink.Getsockname(fd) + if err != nil { + return 0, err + } + addr, ok := sa.(*syscall.SockaddrNetlink) + if !ok { + return 0, errors.New("unexpected socket address type") + } + return addr.Pid, nil +} + +// auditWriter sends messages to the kernel audit subsystem via a netlink +// socket. Each Write call sends the payload as an AUDIT_TRUSTED_APP. +// +// The writer is safe for sequential use; concurrent use requires external +// synchronization. +type auditWriter struct { + fd int + portID uint32 + seq atomic.Uint32 +} + +// Write sends p as the payload of an AUDIT_TRUSTED_APP netlink message. +// The returned byte count reflects only the original payload length. +func (aw *auditWriter) Write(payload []byte) (int, error) { + msg := aw.buildMessage(payload) + addr := &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 0, // kernel + } + if err := netlink.Sendto(aw.fd, msg, 0, addr); err != nil { + return 0, fmt.Errorf("cannot send audit message: %w", err) + } + return len(payload), nil +} + +// Close closes the underlying netlink socket. +func (aw *auditWriter) Close() error { + return netlink.Close(aw.fd) +} + +// nlmsghdrSize is the size of a netlink message header in bytes +// (uint32 + uint16 + uint16 + uint32 + uint32 = 16). +const nlmsghdrSize = 16 + +// buildMessage constructs a raw netlink AUDIT_TRUSTED_APP containing payload. +// The header layout follows struct nlmsghdr from +// https://github.com/torvalds/linux/blob/master/include/uapi/linux/netlink.h#L45 +func (aw *auditWriter) buildMessage(payload []byte) []byte { + totalLen := nlmsghdrSize + uint32(len(payload)) + buf := make([]byte, nlmsgAlign(totalLen)) + + // Write header in native byte order (netlink uses host endianness). + // NativeEndian is supported on all architectures snapd targets: + // amd64, arm, arm64, ppc64le, riscv64. + // See https://cs.opensource.google/go/go/+/refs/tags/go1.26.2:src/encoding/binary/native_endian_little.go + binary.NativeEndian.PutUint32(buf[0:4], totalLen) + binary.NativeEndian.PutUint16(buf[4:6], auditTrustedApp) + binary.NativeEndian.PutUint16(buf[6:8], syscall.NLM_F_REQUEST) // fire-and-forget, no ACK + binary.NativeEndian.PutUint32(buf[8:12], aw.seq.Add(1)) + binary.NativeEndian.PutUint32(buf[12:16], aw.portID) + + // Write payload. + copy(buf[nlmsghdrSize:], payload) + return buf +} + +// nlmsgAlign rounds up to the nearest 4-byte boundary per NLMSG_ALIGN. +func nlmsgAlign(size uint32) uint32 { + return (size + 3) &^ 3 +} diff --git a/seclog/audit_linux_test.go b/seclog/audit_linux_test.go new file mode 100644 index 00000000000..35b43c8a93b --- /dev/null +++ b/seclog/audit_linux_test.go @@ -0,0 +1,320 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !nonativeendian + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "encoding/binary" + "fmt" + "slices" + "syscall" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" +) + +type AuditSuite struct{} + +var _ = Suite(&AuditSuite{}) + +func (s *AuditSuite) TestNlmsgAlignAlreadyAligned(c *C) { + c.Check(seclog.NlmsgAlign(0), Equals, uint32(0)) + c.Check(seclog.NlmsgAlign(4), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(8), Equals, uint32(8)) + c.Check(seclog.NlmsgAlign(16), Equals, uint32(16)) +} + +func (s *AuditSuite) TestNlmsgAlignRoundsUp(c *C) { + c.Check(seclog.NlmsgAlign(1), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(2), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(3), Equals, uint32(4)) + c.Check(seclog.NlmsgAlign(5), Equals, uint32(8)) + c.Check(seclog.NlmsgAlign(17), Equals, uint32(20)) +} + +func (s *AuditSuite) TestBuildMessageHeaderLayout(c *C) { + aw := &seclog.AuditWriter{} + + payload := []byte("hello") + msg := seclog.AuditWriterBuildMessage(aw, payload) + + // Total length: 16 (header) + 5 (payload) = 21, aligned to 24. + c.Assert(len(msg), Equals, 24) + + // nlmsghdr fields in native byte order. + totalLen := binary.NativeEndian.Uint32(msg[0:4]) + c.Check(totalLen, Equals, uint32(21)) + + msgType := binary.NativeEndian.Uint16(msg[4:6]) + c.Check(msgType, Equals, uint16(seclog.AuditTrustedApp)) + + flags := binary.NativeEndian.Uint16(msg[6:8]) + c.Check(flags, Equals, uint16(syscall.NLM_F_REQUEST)) + + seq := binary.NativeEndian.Uint32(msg[8:12]) + c.Check(seq, Equals, uint32(1)) + + portID := binary.NativeEndian.Uint32(msg[12:16]) + c.Check(portID, Equals, uint32(0)) + + // Payload follows header. + c.Check(string(msg[seclog.NlmsghdrSize:seclog.NlmsghdrSize+5]), Equals, "hello") + + // Padding bytes after payload should be zero. + c.Check(msg[21], Equals, byte(0)) + c.Check(msg[22], Equals, byte(0)) + c.Check(msg[23], Equals, byte(0)) +} + +func (s *AuditSuite) TestBuildMessagePortID(c *C) { + aw := &seclog.AuditWriter{} + seclog.AuditWriterSetPortID(aw, 42) + + msg := seclog.AuditWriterBuildMessage(aw, []byte("x")) + + portID := binary.NativeEndian.Uint32(msg[12:16]) + c.Check(portID, Equals, uint32(42)) +} + +func (s *AuditSuite) TestBuildMessageSequenceIncrements(c *C) { + aw := &seclog.AuditWriter{} + + msg1 := seclog.AuditWriterBuildMessage(aw, []byte("a")) + msg2 := seclog.AuditWriterBuildMessage(aw, []byte("b")) + msg3 := seclog.AuditWriterBuildMessage(aw, []byte("c")) + + seq1 := binary.NativeEndian.Uint32(msg1[8:12]) + seq2 := binary.NativeEndian.Uint32(msg2[8:12]) + seq3 := binary.NativeEndian.Uint32(msg3[8:12]) + + c.Check(seq1, Equals, uint32(1)) + c.Check(seq2, Equals, uint32(2)) + c.Check(seq3, Equals, uint32(3)) +} + +func (s *AuditSuite) TestBuildMessageAlignedPayload(c *C) { + aw := &seclog.AuditWriter{} + + // Payload of exactly 4 bytes: total = 20 which is already aligned. + msg := seclog.AuditWriterBuildMessage(aw, []byte("abcd")) + c.Check(len(msg), Equals, 20) + + totalLen := binary.NativeEndian.Uint32(msg[0:4]) + c.Check(totalLen, Equals, uint32(20)) +} + +func (s *AuditSuite) TestBuildMessageEmptyPayload(c *C) { + aw := &seclog.AuditWriter{} + + msg := seclog.AuditWriterBuildMessage(aw, []byte{}) + + // 16-byte header, already aligned. + c.Check(len(msg), Equals, 16) + + totalLen := binary.NativeEndian.Uint32(msg[0:4]) + c.Check(totalLen, Equals, uint32(16)) +} + +func (s *AuditSuite) TestNlmsghdrSizeConstant(c *C) { + // nlmsghdr is: uint32 + uint16 + uint16 + uint32 + uint32 = 16 + c.Check(seclog.NlmsghdrSize, Equals, 16) +} + +func (s *AuditSuite) TestAuditSinkRegistered(c *C) { + // The init() in audit_linux.go registers SinkAudit. + // Setup should not fail with "unknown sink" for SinkAudit. + // We verify indirectly: if the sink were missing, Setup would + // return "unknown sink". + restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) + defer restore() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, "test", seclog.LevelInfo) + // This should fail with "unknown implementation" (not "unknown sink"), + // proving the audit sink is registered. + c.Check(err, ErrorMatches, `cannot set up security logger: unknown implementation "slog"`) +} + +// mockNetlinkOps records calls and returns configurable results. +type mockNetlinkOps struct { + socketFD int + socketErr error + bindErr error + getsockname syscall.Sockaddr + getsocknErr error + sendtoData []byte + sendtoErr error + closedFDs []int + closeErr error +} + +func (m *mockNetlinkOps) Socket(domain, typ, proto int) (int, error) { + return m.socketFD, m.socketErr +} + +func (m *mockNetlinkOps) Bind(fd int, sa syscall.Sockaddr) error { + return m.bindErr +} + +func (m *mockNetlinkOps) Getsockname(fd int) (syscall.Sockaddr, error) { + return m.getsockname, m.getsocknErr +} + +func (m *mockNetlinkOps) Sendto(fd int, p []byte, flags int, to syscall.Sockaddr) error { + m.sendtoData = slices.Clone(p) + return m.sendtoErr +} + +func (m *mockNetlinkOps) Close(fd int) error { + m.closedFDs = append(m.closedFDs, fd) + return m.closeErr +} + +// Ensure mockNetlinkOps satisfies the interface. +var _ seclog.NetlinkOps = (*mockNetlinkOps)(nil) + +func (s *AuditSuite) TestOpenSuccess(c *C) { + mock := &mockNetlinkOps{ + socketFD: 42, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 99, + }, + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + c.Assert(writer, NotNil) +} + +func (s *AuditSuite) TestOpenSocketError(c *C) { + mock := &mockNetlinkOps{ + socketErr: fmt.Errorf("permission denied"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot open audit socket: permission denied") +} + +func (s *AuditSuite) TestOpenBindError(c *C) { + mock := &mockNetlinkOps{ + socketFD: 10, + bindErr: fmt.Errorf("address in use"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot bind audit socket: address in use") + // Socket should have been closed on bind failure. + c.Check(mock.closedFDs, DeepEquals, []int{10}) +} + +func (s *AuditSuite) TestOpenGetsocknameError(c *C) { + mock := &mockNetlinkOps{ + socketFD: 10, + getsocknErr: fmt.Errorf("bad fd"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot get audit socket port ID: bad fd") + c.Check(mock.closedFDs, DeepEquals, []int{10}) +} + +func (s *AuditSuite) TestOpenGetsocknameWrongAddressType(c *C) { + mock := &mockNetlinkOps{ + socketFD: 10, + // Return a non-netlink address type. + getsockname: &syscall.SockaddrUnix{Name: "/tmp/sock"}, + } + restore := seclog.MockNetlink(mock) + defer restore() + + _, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, ErrorMatches, "cannot get audit socket port ID: unexpected socket address type") + c.Check(mock.closedFDs, DeepEquals, []int{10}) +} + +func (s *AuditSuite) TestWriteSendtoError(c *C) { + mock := &mockNetlinkOps{ + socketFD: 7, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 1, + }, + sendtoErr: fmt.Errorf("no buffer space"), + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + + _, err = writer.Write([]byte("test")) + c.Assert(err, ErrorMatches, "cannot send audit message: no buffer space") +} + +func (s *AuditSuite) TestWriteSuccess(c *C) { + mock := &mockNetlinkOps{ + socketFD: 7, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 1, + }, + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + + n, err := writer.Write([]byte("hello")) + c.Assert(err, IsNil) + c.Check(n, Equals, 5) + // The mock captured the raw netlink message. + c.Check(len(mock.sendtoData) > seclog.NlmsghdrSize, Equals, true) +} + +func (s *AuditSuite) TestClose(c *C) { + mock := &mockNetlinkOps{ + socketFD: 7, + getsockname: &syscall.SockaddrNetlink{ + Family: syscall.AF_NETLINK, + Pid: 1, + }, + } + restore := seclog.MockNetlink(mock) + defer restore() + + writer, err := seclog.AuditSinkFactory{}.Open("test") + c.Assert(err, IsNil) + + closer, ok := writer.(interface{ Close() error }) + c.Assert(ok, Equals, true) + err = closer.Close() + c.Assert(err, IsNil) + c.Check(mock.closedFDs, DeepEquals, []int{7}) +} diff --git a/seclog/export_audit_linux_test.go b/seclog/export_audit_linux_test.go new file mode 100644 index 00000000000..891ec26061e --- /dev/null +++ b/seclog/export_audit_linux_test.go @@ -0,0 +1,49 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !nonativeendian + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "github.com/snapcore/snapd/testutil" +) + +type AuditWriter = auditWriter + +type AuditSinkFactory = auditSinkFactory + +type NetlinkOps = netlinkOps + +var NlmsgAlign = nlmsgAlign + +const NlmsghdrSize = nlmsghdrSize + +const AuditTrustedApp = auditTrustedApp + +func AuditWriterBuildMessage(aw *auditWriter, payload []byte) []byte { + return aw.buildMessage(payload) +} + +func AuditWriterSetPortID(aw *auditWriter, id uint32) { + aw.portID = id +} + +func MockNetlink(ops netlinkOps) (restore func()) { + return testutil.Mock(&netlink, ops) +} diff --git a/seclog/export_slog_test.go b/seclog/export_slog_test.go new file mode 100644 index 00000000000..f893af128e0 --- /dev/null +++ b/seclog/export_slog_test.go @@ -0,0 +1,27 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +type ( + SlogImplementation = slogImplementation + SlogLogger = slogLogger + LevelWriter = levelWriter +) diff --git a/seclog/export_test.go b/seclog/export_test.go new file mode 100644 index 00000000000..b6cf469437f --- /dev/null +++ b/seclog/export_test.go @@ -0,0 +1,116 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "io" + + "github.com/snapcore/snapd/testutil" +) + +var NewNopLogger = newNopLogger + +var RegisterImpl = registerImpl +var RegisterSink = registerSink + +type ( + ImplFactory = implFactory + SinkFactory = sinkFactory + SecurityLogger = securityLogger +) + +func MockSinks(m map[Sink]sinkFactory) (restore func()) { + restore = testutil.Backup(&sinks) + sinks = m + return restore +} + +// sinkFunc adapts a plain function to the [sinkFactory] interface. +type sinkFunc func(string) (io.Writer, error) + +// SinkFunc exports sinkFunc for use in external test packages. +type SinkFunc = sinkFunc + +func (f sinkFunc) Open(appID string) (io.Writer, error) { return f(appID) } + +// MockNewSink is a convenience wrapper that replaces the audit sink factory +// in the sinks map. +func MockNewSink(f func(string) (io.Writer, error)) (restore func()) { + restore = testutil.Backup(&sinks) + sinks = map[Sink]sinkFactory{ + SinkAudit: sinkFunc(f), + } + return restore +} + +func MockImplementations(m map[Impl]implFactory) (restore func()) { + restore = testutil.Backup(&implementations) + implementations = m + return restore +} + +func MockGlobalLogger(l securityLogger) (restore func()) { + restore = testutil.Backup(&globalLogger) + globalLogger = l + return restore +} + +func MockGlobalCloser(c io.Closer) (restore func()) { + restore = testutil.Backup(&globalCloser) + globalCloser = c + return restore +} + +// LoggerSetup is the exported alias for the unexported loggerSetup type, +// allowing tests to create and mock setup state. +type LoggerSetup = loggerSetup + +// NewLoggerSetup constructs a LoggerSetup for use in tests. +func NewLoggerSetup(impl Impl, sink Sink, appID string, minLevel Level) *LoggerSetup { + return &LoggerSetup{impl: impl, sink: sink, appID: appID, minLevel: minLevel} +} + +func MockGlobalSetup(s *LoggerSetup) (restore func()) { + restore = testutil.Backup(&globalSetup) + globalSetup = s + return restore +} + +const MaxWriteFailures = maxWriteFailures + +func MockWriteFailures(n int) (restore func()) { + restore = testutil.Backup(&writeFailures) + writeFailures = n + return restore +} + +func MockFailed(f bool) (restore func()) { + restore = testutil.Backup(&failed) + failed = f + return restore +} + +func GetFailed() bool { + return failed +} + +func GetWriteFailures() int { + return writeFailures +} diff --git a/seclog/nop.go b/seclog/nop.go new file mode 100644 index 00000000000..63c60986890 --- /dev/null +++ b/seclog/nop.go @@ -0,0 +1,46 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +// nopLogger provides a no-operation [securityLogger] implementation. +type nopLogger struct{} + +// Ensure [nopLogger] implements [securityLogger]. +var _ securityLogger = (*nopLogger)(nil) + +func newNopLogger() securityLogger { + return nopLogger{} +} + +// LogLoggingEnabled implements [securityLogger.LogLoggingEnabled]. +func (nopLogger) LogLoggingEnabled() { +} + +// LogLoggingDisabled implements [securityLogger.LogLoggingDisabled]. +func (nopLogger) LogLoggingDisabled() { +} + +// LogLoginSuccess implements [securityLogger.LogLoginSuccess]. +func (nopLogger) LogLoginSuccess(user SnapdUser) { +} + +// LogLoginFailure implements [securityLogger.LogLoginFailure]. +func (nopLogger) LogLoginFailure(user SnapdUser, reason Reason) { +} diff --git a/seclog/nop_test.go b/seclog/nop_test.go new file mode 100644 index 00000000000..c479459238b --- /dev/null +++ b/seclog/nop_test.go @@ -0,0 +1,73 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type NopSuite struct { + testutil.BaseTest +} + +var _ = Suite(&NopSuite{}) + +func (s *NopSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) +} + +func (s *NopSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +func (s *NopSuite) TestLogLoggingEnabled(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoggingEnabled() +} + +func (s *NopSuite) TestLogLoggingDisabled(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoggingDisabled() +} + +func (s *NopSuite) TestLogLoginSuccess(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoginSuccess(seclog.SnapdUser{StoreUserEmail: "user@gmail.com"}) +} + +func (s *NopSuite) TestLogLoginFailure(c *C) { + logger := seclog.NewNopLogger() + c.Assert(logger, NotNil) + + // nop logger discards all messages without error + logger.LogLoginFailure(seclog.SnapdUser{StoreUserEmail: "user@gmail.com"}, seclog.Reason{}) +} diff --git a/seclog/seclog.go b/seclog/seclog.go new file mode 100644 index 00000000000..8208fb0e67d --- /dev/null +++ b/seclog/seclog.go @@ -0,0 +1,439 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/snapcore/snapd/logger" +) + +// Level is the importance or severity of a log event. +// The higher the level, the more severe the event. +type Level int + +// Log levels. +const ( + LevelDebug Level = 1 + LevelInfo Level = 2 + LevelWarn Level = 3 + LevelError Level = 4 + LevelCritical Level = 5 +) + +// String returns a name for the level. +// If the level has a name, then that name +// in uppercase is returned. +// If the level is between named values, then +// an integer is appended to the uppercased name. +// Examples: +// +// LevelWarn.String() => "WARN" +// (LevelCritical+2).String() => "CRITICAL+2" +func (l Level) String() string { + str := func(base string, val Level) string { + if val == 0 { + return base + } + return fmt.Sprintf("%s%+d", base, val) + } + + switch { + case l < LevelInfo: + return str("DEBUG", l-LevelDebug) + case l < LevelWarn: + return str("INFO", l-LevelInfo) + case l < LevelError: + return str("WARN", l-LevelWarn) + case l < LevelCritical: + return str("ERROR", l-LevelError) + default: + return str("CRITICAL", l-LevelCritical) + } +} + +// Impl represents a known logger implementation identifier used for +// registration and selection of security loggers. +type Impl string + +// Logger implementations. +const ( + ImplSlog Impl = "slog" // slog based structured logger +) + +// Sink identifies a log output destination. +type Sink string + +// Sink types. +const ( + SinkAudit Sink = "audit" // kernel audit via netlink +) + +// SnapdUser represents the identity of a user for security log events. +// The slog output schema is defined by [SnapdUser.LogValue], which +// renders Expiration as "never" for zero values instead of emitting a +// zero-value datetime. +type SnapdUser struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration time.Time `json:"expiration"` +} + +// String returns a colon-separated description of the user in the form +// "::". Fields that are unset use +// "unknown" as a placeholder. A zero ID is treated as unset. +func (u SnapdUser) String() string { + const unknown = "unknown" + + id := unknown + if u.ID != 0 { + id = fmt.Sprintf("%d", u.ID) + } + + email := unknown + if u.StoreUserEmail != "" { + email = u.StoreUserEmail + } + + name := unknown + if u.SystemUserName != "" { + name = u.SystemUserName + } + + return id + ":" + email + ":" + name +} + +// Reason codes are stable identifiers for security audit events. +const ( + ReasonInvalidCredentials = "invalid-credentials" + ReasonTwoFactorRequired = "two-factor-required" + ReasonTwoFactorFailed = "two-factor-failed" + ReasonInvalidAuthData = "invalid-auth-data" + ReasonPasswordPolicy = "password-policy" + ReasonInternal = "internal" +) + +// Reason describes why a security event happened. +type Reason struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// String returns a colon-separated representation in the form +// ":". Fields that are unset use "unknown" as a +// placeholder. +func (r Reason) String() string { + const unknown = "unknown" + + code := unknown + if r.Code != "" { + code = r.Code + } + + message := unknown + if r.Message != "" { + message = r.Message + } + + return code + ":" + message +} + +// securityLogger defines the interface for emitting structured security +// audit events. Implementations are created by an [implFactory] and write +// to a configured sink. +type securityLogger interface { + LogLoggingEnabled() + LogLoggingDisabled() + LogLoginSuccess(user SnapdUser) + LogLoginFailure(user SnapdUser, reason Reason) +} + +// loggerSetup holds the configuration provided to Setup. +type loggerSetup struct { + impl Impl + sink Sink + appID string + minLevel Level +} + +// implFactory provides functions required for constructing a [securityLogger]. +// It is intended for registration of available loggers. +type implFactory interface { + // New creates a securityLogger that writes to writer. Messages with a + // severity below minLevel are silently dropped. + New(writer io.Writer, appID string, minLevel Level) securityLogger +} + +// sinkFactory creates an [io.Writer] for a log output destination. +// The appID identifies the application opening the sink and may be +// used by implementations for tagging or routing. +// +// If the returned writer also implements [io.Closer], it will be closed +// automatically when the sink is replaced or disabled. +type sinkFactory interface { + Open(appID string) (io.Writer, error) +} + +var ( + implementations = map[Impl]implFactory{} + sinks = map[Sink]sinkFactory{} + globalLogger securityLogger = newNopLogger() + globalCloser io.Closer + globalSetup *loggerSetup + writeFailures int + failed bool + lock sync.Mutex +) + +// maxWriteFailures is the number of consecutive write failures +// tolerated before the security logger enters the failed state and +// is automatically disabled. +const maxWriteFailures = 3 + +// Setup stores the logger configuration and attempts to enable the +// security logger immediately. If the log sink cannot be opened (e.g. +// because the journal namespace is not active yet), the configuration +// is still stored and a non-fatal "security logger disabled" error is +// returned. A subsequent call to Enable will re-attempt activation. +// +// Although Setup is reentrant, it is intended to be called exactly +// once per application, typically during early initialization. +func Setup(impl Impl, sink Sink, appID string, minLevel Level) error { + lock.Lock() + defer lock.Unlock() + + if _, exists := implementations[impl]; !exists { + return fmt.Errorf("cannot set up security logger: unknown implementation %q", string(impl)) + } + + if _, exists := sinks[sink]; !exists { + return fmt.Errorf("cannot set up security logger: unknown sink %q", string(sink)) + } + + globalSetup = &loggerSetup{impl: impl, sink: sink, appID: appID, minLevel: minLevel} + if err := enableLocked(); err != nil { + return fmt.Errorf("security logger disabled: %v", err) + } + + return nil +} + +// Enable opens the security log sink using the configuration stored by Setup, +// activating the security logger. If the sink is already open, it is closed +// and re-opened, refreshing the connection to the journal namespace. +// Returns an error if Setup has not been called or if the sink cannot be opened. +func Enable() error { + lock.Lock() + defer lock.Unlock() + + if globalSetup == nil { + return fmt.Errorf("cannot enable security logger: setup has not been called") + } + return enableLocked() +} + +// Disable closes the security log sink and resets the global logger to nop. +// The stored configuration is retained so that Enable can re-open the sink +// later. Returns an error if Setup has not been called or if the sink +// cannot be closed. +func Disable() error { + lock.Lock() + defer lock.Unlock() + + if globalSetup == nil { + return fmt.Errorf("cannot disable security logger: setup has not been called") + } + globalLogger.LogLoggingDisabled() + logger.Noticef("security logger disabled") + return closeSinkLocked() +} + +// LogLoginSuccess logs a successful login using the global security logger. +func LogLoginSuccess(user SnapdUser) { + lock.Lock() + defer lock.Unlock() + + globalLogger.LogLoginSuccess(user) +} + +// LogLoginFailure logs a failed login attempt using the global security logger. +func LogLoginFailure(user SnapdUser, reason Reason) { + lock.Lock() + defer lock.Unlock() + + globalLogger.LogLoginFailure(user, reason) +} + +// registerImpl makes a logger factory available by name. +// The registration pattern allows implementations to be conditionally +// compiled via build tags without requiring the core package to +// import them directly. +// Should be called from the init() of the implementation file. +func registerImpl(name Impl, factory implFactory) { + lock.Lock() + defer lock.Unlock() + + if _, exists := implementations[name]; exists { + panic(fmt.Sprintf("attempting re-registration for existing logger %q", name)) + } + implementations[name] = factory +} + +// registerSink makes a sink factory available by name. +// The registration pattern allows sinks to be conditionally compiled +// via build tags without requiring the core package to import them +// directly. +// Should be called from the init() of the sink file. +func registerSink(name Sink, factory sinkFactory) { + lock.Lock() + defer lock.Unlock() + + if _, exists := sinks[name]; exists { + panic(fmt.Sprintf("attempting re-registration for existing sink %q", name)) + } + sinks[name] = factory +} + +// enableLocked resolves the logger factory, opens the sink, and activates the +// logger. Must be called with lock held and globalSetup non-nil. +func enableLocked() error { + factory, exists := implementations[globalSetup.impl] + if !exists { + return fmt.Errorf("internal error: implementation %q missing", string(globalSetup.impl)) + } + + newSink, exists := sinks[globalSetup.sink] + if !exists { + return fmt.Errorf("internal error: sink %q missing", string(globalSetup.sink)) + } + + writer, err := openSinkLocked(newSink, globalSetup.appID) + if err != nil { + return fmt.Errorf("cannot enable security logger: %w", err) + } + + // Wrap the writer with failure tracking so that repeated write + // errors automatically disable the logger. + tracked := &failureTrackingWriter{ + writer: writer, + writeFailures: &writeFailures, + failed: &failed, + maxFailures: maxWriteFailures, + onThresholdReached: func(failures int, lastErr error) { + logger.Noticef("security logger failed after %d consecutive write errors, disabling (last error: %v)", failures, lastErr) + closeSinkLocked() + }, + } + globalLogger = factory.New(tracked, globalSetup.appID, globalSetup.minLevel) + writeFailures = 0 + failed = false + globalLogger.LogLoggingEnabled() + logger.Noticef("security logger enabled") + return nil +} + +// openSinkLocked opens the log sink and manages the closer. Any previously +// open sink is closed first. Must be called with lock held. +func openSinkLocked(factory sinkFactory, appID string) (io.Writer, error) { + writer, err := factory.Open(appID) + if err != nil { + return nil, err + } + + if globalCloser != nil { + globalCloser.Close() + globalCloser = nil + } + + // If the writer also implements io.Closer, track it so + // the sink is closed when replaced or disabled. + if closer, ok := writer.(io.Closer); ok { + globalCloser = closer + } + + return writer, nil +} + +// closeSinkLocked closes the security log sink and resets the global logger to +// nop. Must be called with lock held. +func closeSinkLocked() error { + globalLogger = newNopLogger() + if globalCloser != nil { + err := globalCloser.Close() + globalCloser = nil + return err + } + return nil +} + +// levelWriter extends [io.Writer] with per-message level control. Writers +// that implement this interface allow log handlers to set the severity for +// each message before writing. +// +// This interface is defined here rather than in slog.go so that +// [failureTrackingWriter] can implement it without a build-tag +// dependency on log/slog. The slog layer's [levelHandler] and the +// audit sink's [auditWriter] are the primary consumers. +type levelWriter interface { + io.Writer + SetLevel(Level) +} + +// failureTrackingWriter wraps an [io.Writer] and counts consecutive +// write failures. When maxFailures consecutive errors are reached it +// invokes onThresholdReached and marks the logger as failed. +// +// All mutable state (writeFailures, failed) is injected via pointers +// so that the writer does not implicitly depend on package globals. +// The caller must hold [lock] when calling Write; since Write is +// invoked from within a locked Log* call, the lock is already held. +type failureTrackingWriter struct { + writer io.Writer + writeFailures *int + failed *bool + maxFailures int + onThresholdReached func(failures int, lastErr error) +} + +func (w *failureTrackingWriter) Write(p []byte) (int, error) { + n, err := w.writer.Write(p) + if err != nil { + *w.writeFailures++ + if *w.writeFailures >= w.maxFailures && !*w.failed { + *w.failed = true + w.onThresholdReached(*w.writeFailures, err) + } + return n, err + } + *w.writeFailures = 0 + return n, nil +} + +// SetLevel implements [levelWriter] so the tracking wrapper is +// transparent to the [levelHandler]. +func (w *failureTrackingWriter) SetLevel(l Level) { + if lw, ok := w.writer.(levelWriter); ok { + lw.SetLevel(l) + } +} diff --git a/seclog/seclog_test.go b/seclog/seclog_test.go new file mode 100644 index 00000000000..c741cd90b91 --- /dev/null +++ b/seclog/seclog_test.go @@ -0,0 +1,630 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "testing" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/logger" + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type SecLogSuite struct { + testutil.BaseTest + buf *bytes.Buffer + appID string +} + +var _ = Suite(&SecLogSuite{}) + +func TestSecLog(t *testing.T) { TestingT(t) } + +func (s *SecLogSuite) SetUpSuite(c *C) { + s.buf = &bytes.Buffer{} + s.appID = "canonical.snapd" +} + +func (s *SecLogSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + s.buf.Reset() +} + +func (s *SecLogSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +func (s *SecLogSuite) TestString(c *C) { + levels := []seclog.Level{ + seclog.LevelDebug - 1, + seclog.LevelDebug, + seclog.LevelInfo, + seclog.LevelWarn, + seclog.LevelError, + seclog.LevelError + 1, + seclog.LevelCritical, + seclog.LevelCritical + 2, + } + + expected := []string{ + "DEBUG-1", + "DEBUG", + "INFO", + "WARN", + "ERROR", + "CRITICAL", + "CRITICAL", + "CRITICAL+2", + } + + c.Assert(len(levels), Equals, len(expected)) + + obtained := make([]string, 0, len(levels)) + + for _, level := range levels { + obtained = append(obtained, level.String()) + } + + c.Assert(expected, DeepEquals, obtained) +} + +func (s *SecLogSuite) TestSnapdUserString(c *C) { + // All fields set. + c.Check(seclog.SnapdUser{ + ID: 42, StoreUserEmail: "a@b.com", SystemUserName: "jdoe", + }.String(), Equals, "42:a@b.com:jdoe") + + // All fields zero/empty — all "unknown". + c.Check(seclog.SnapdUser{}.String(), Equals, "unknown:unknown:unknown") + + // Only ID set. + c.Check(seclog.SnapdUser{ID: 7}.String(), Equals, "7:unknown:unknown") + + // Only email set. + c.Check(seclog.SnapdUser{StoreUserEmail: "x@y.z"}.String(), Equals, "unknown:x@y.z:unknown") + + // Only username set. + c.Check(seclog.SnapdUser{SystemUserName: "root"}.String(), Equals, "unknown:unknown:root") +} + +func (s *SecLogSuite) TestReasonString(c *C) { + // Both fields set. + c.Check(seclog.Reason{ + Code: seclog.ReasonInvalidCredentials, Message: "bad password", + }.String(), Equals, "invalid-credentials:bad password") + + // Both fields empty — all "unknown". + c.Check(seclog.Reason{}.String(), Equals, "unknown:unknown") + + // Only code set. + c.Check(seclog.Reason{Code: seclog.ReasonInternal}.String(), Equals, "internal:unknown") + + // Only message set. + c.Check(seclog.Reason{Message: "something broke"}.String(), Equals, "unknown:something broke") +} + +func (s *SecLogSuite) TestRegister(c *C) { + restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) + defer restore() + + seclog.RegisterImpl(seclog.ImplSlog, seclog.SlogImplementation{}) + + // registering the same implementation again panics + c.Assert(func() { seclog.RegisterImpl(seclog.ImplSlog, seclog.SlogImplementation{}) }, PanicMatches, + `attempting re-registration for existing logger "slog"`) +} + +func (s *SecLogSuite) TestRegisterSinkDuplicatePanics(c *C) { + restore := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) + defer restore() + + dummy := seclog.SinkFunc(func(string) (io.Writer, error) { return nil, nil }) + seclog.RegisterSink(seclog.SinkAudit, dummy) + + // registering the same sink again panics + c.Assert(func() { seclog.RegisterSink(seclog.SinkAudit, dummy) }, PanicMatches, + `attempting re-registration for existing sink "audit"`) +} + +func (s *SecLogSuite) TestSetupUnknownImpl(c *C) { + restore := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) + defer restore() + + err := seclog.Setup("unknown", seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, + `cannot set up security logger: unknown implementation "unknown"`) +} + +func (s *SecLogSuite) TestSetupUnknownSink(c *C) { + restore := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) + defer restore() + + err := seclog.Setup(seclog.ImplSlog, "unknown", s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, + `cannot set up security logger: unknown sink "unknown"`) +} + +func (s *SecLogSuite) TestSetupSinkError(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return nil, fmt.Errorf("journal unavailable") + }) + defer restore() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, ErrorMatches, "security logger disabled: cannot enable security logger: journal unavailable") +} + +func (s *SecLogSuite) TestSetupSuccess(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + c.Check(appID, Equals, s.appID) + return s.buf, nil + }) + defer restore() + + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // verify the logger is functional by logging through it + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, SystemUserName: "testuser"}) + c.Check(s.buf.Len() > 0, Equals, true) +} + +func (s *SecLogSuite) setupSlogLogger(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + s.AddCleanup(restore) + + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + s.AddCleanup(restoreLogger) + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // Reset buffer after Setup, which logs the "logging enabled" event. + s.buf.Reset() +} + +func (s *SecLogSuite) TestLogLoginSuccess(c *C) { + s.setupSlogLogger(c) + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@example.com", + SystemUserName: "jdoe", + } + seclog.LogLoginSuccess(user) + + var obtained map[string]any + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "INFO") + c.Check(obtained["description"], Equals, + "User 42:user@example.com:jdoe login success") + c.Check(obtained["app_id"], Equals, s.appID) + c.Check(obtained["category"], Equals, "AUTHN") + c.Check(obtained["event"], Equals, "authn_login_success") + userMap, ok := obtained["user"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(userMap["snapd-user-id"], Equals, float64(42)) + c.Check(userMap["store-user-email"], Equals, "user@example.com") + c.Check(userMap["system-user-name"], Equals, "jdoe") + c.Check(obtained["type"], Equals, "security") +} + +func (s *SecLogSuite) TestLogLoginFailure(c *C) { + s.setupSlogLogger(c) + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@example.com", + SystemUserName: "jdoe", + } + seclog.LogLoginFailure(user, seclog.Reason{Code: seclog.ReasonInvalidCredentials, Message: "invalid credentials"}) + + var obtained map[string]any + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "WARN") + c.Check(obtained["description"], Equals, + "User 42:user@example.com:jdoe login failure: invalid-credentials:invalid credentials") + c.Check(obtained["app_id"], Equals, s.appID) + c.Check(obtained["category"], Equals, "AUTHN") + c.Check(obtained["event"], Equals, "authn_login_failure") + userMap, ok := obtained["user"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(userMap["snapd-user-id"], Equals, float64(42)) + c.Check(userMap["store-user-email"], Equals, "user@example.com") + c.Check(userMap["system-user-name"], Equals, "jdoe") + errMap, ok := obtained["error"].(map[string]any) + c.Assert(ok, Equals, true) + c.Check(errMap["code"], Equals, seclog.ReasonInvalidCredentials) + c.Check(errMap["message"], Equals, "invalid credentials") + c.Check(obtained["type"], Equals, "security") +} + +// closeTracker is a test helper that records whether Close was called. +type closeTracker struct { + closed bool + err error +} + +func (ct *closeTracker) Close() error { + ct.closed = true + return ct.err +} + +func (s *SecLogSuite) TestDisableClosesTheSink(c *C) { + tracker := &closeTracker{} + restoreCloser := seclog.MockGlobalCloser(tracker) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, IsNil) + c.Check(tracker.closed, Equals, true) +} + +func (s *SecLogSuite) TestDisableLogsDisabledEvent(c *C) { + s.setupSlogLogger(c) + + err := seclog.Disable() + c.Assert(err, IsNil) + + var obtained map[string]any + err = json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["level"], Equals, "CRITICAL") + c.Check(obtained["description"], Equals, "Security logging disabled") + c.Check(obtained["category"], Equals, "SYS") + c.Check(obtained["event"], Equals, "sys_logging_disabled") +} + +func (s *SecLogSuite) TestDisableWithNoSetupReturnsError(c *C) { + restoreCloser := seclog.MockGlobalCloser(nil) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup(nil) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, ErrorMatches, "cannot disable security logger: setup has not been called") +} + +func (s *SecLogSuite) TestEnableWithNoSetupReturnsError(c *C) { + restoreSetup := seclog.MockGlobalSetup(nil) + defer restoreSetup() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, "cannot enable security logger: setup has not been called") +} + +func (s *SecLogSuite) TestEnableWithMissingImpl(c *C) { + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + restoreImpls := seclog.MockImplementations(map[seclog.Impl]seclog.ImplFactory{}) + defer restoreImpls() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, `internal error: implementation "slog" missing`) +} + +func (s *SecLogSuite) TestEnableWithMissingSink(c *C) { + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + restoreSinks := seclog.MockSinks(map[seclog.Sink]seclog.SinkFactory{}) + defer restoreSinks() + + err := seclog.Enable() + c.Assert(err, ErrorMatches, `internal error: sink "audit" missing`) +} + +func (s *SecLogSuite) TestEnableAfterDisable(c *C) { + s.setupSlogLogger(c) + + err := seclog.Disable() + c.Assert(err, IsNil) + s.buf.Reset() + + err = seclog.Enable() + c.Assert(err, IsNil) + s.buf.Reset() + user := seclog.SnapdUser{ + ID: 1, + StoreUserEmail: "a@b.com", + SystemUserName: "u", + } + seclog.LogLoginSuccess(user) + + var obtained map[string]any + err = json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained["event"], Equals, "authn_login_success") +} + +func (s *SecLogSuite) TestDisableIsIdempotent(c *C) { + tracker := &closeTracker{} + restoreCloser := seclog.MockGlobalCloser(tracker) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, IsNil) + c.Check(tracker.closed, Equals, true) + + // second call does not error even though closer is now nil + err = seclog.Disable() + c.Assert(err, IsNil) +} + +func (s *SecLogSuite) TestDisablePropagatesError(c *C) { + tracker := &closeTracker{err: fmt.Errorf("disk full")} + restoreCloser := seclog.MockGlobalCloser(tracker) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + restoreSetup := seclog.MockGlobalSetup( + seclog.NewLoggerSetup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo)) + defer restoreSetup() + + err := seclog.Disable() + c.Assert(err, ErrorMatches, "disk full") +} + +// writeCloseTracker is a test helper that implements io.WriteCloser and +// records whether Close was called. +type writeCloseTracker struct { + bytes.Buffer + closed bool +} + +func (wc *writeCloseTracker) Close() error { + wc.closed = true + return nil +} + +func (s *SecLogSuite) TestSetupClosesPreviousSink(c *C) { + first := &writeCloseTracker{} + second := &writeCloseTracker{} + call := 0 + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + call++ + if call == 1 { + return first, nil + } + return second, nil + }) + defer restore() + restoreCloser := seclog.MockGlobalCloser(nil) + defer restoreCloser() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + // first setup + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + c.Check(first.closed, Equals, false) + + // second setup should close the first sink + err = seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + c.Check(first.closed, Equals, true) + c.Check(second.closed, Equals, false) +} + +// countingWriter counts successful writes before switching to errors. +type countingWriter struct { + buf bytes.Buffer + successes int // number of remaining successful writes +} + +func (w *countingWriter) Write(p []byte) (int, error) { + if w.successes > 0 { + w.successes-- + return w.buf.Write(p) + } + return 0, fmt.Errorf("write failed") +} + +func (s *SecLogSuite) TestWriteFailuresDisableAfterThreshold(c *C) { + // Allow LogLoggingEnabled to succeed so writeFailures starts at 0; + // only the test loop writes trigger failures. + cw := &countingWriter{successes: 1} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return cw, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + logBuf, restoreStdLogger := logger.MockLogger() + defer restoreStdLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + logBuf.Reset() + + user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + + // Exactly maxWriteFailures consecutive failures trigger auto-disable. + for i := 0; i < seclog.MaxWriteFailures; i++ { + seclog.LogLoginSuccess(user) + } + + c.Check(seclog.GetFailed(), Equals, true) + c.Check(seclog.GetWriteFailures(), Equals, seclog.MaxWriteFailures) + c.Check(logBuf.String(), testutil.Contains, + "security logger failed after 3 consecutive write errors, disabling") +} + +func (s *SecLogSuite) TestWriteFailuresDoNotDisableBelowThreshold(c *C) { + // Allow LogLoggingEnabled to succeed so writeFailures starts at 0. + cw := &countingWriter{successes: 1} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return cw, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + + // Fewer than maxWriteFailures failures should not trigger auto-disable. + for i := 0; i < seclog.MaxWriteFailures-1; i++ { + seclog.LogLoginSuccess(user) + } + + c.Check(seclog.GetFailed(), Equals, false) + c.Check(seclog.GetWriteFailures(), Equals, seclog.MaxWriteFailures-1) +} + +func (s *SecLogSuite) TestWriteSuccessResetsFailureCount(c *C) { + cw := &countingWriter{successes: 100} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return cw, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // Simulate some failures below the threshold. + restoreFailures := seclog.MockWriteFailures(seclog.MaxWriteFailures - 1) + defer restoreFailures() + + user := seclog.SnapdUser{ID: 1, SystemUserName: "test"} + // A successful write resets the counter. + seclog.LogLoginSuccess(user) + + c.Check(seclog.GetWriteFailures(), Equals, 0) + c.Check(seclog.GetFailed(), Equals, false) +} + +func (s *SecLogSuite) TestEnableResetsFailureState(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + // Simulate a failed state. + restoreFailures := seclog.MockWriteFailures(seclog.MaxWriteFailures) + defer restoreFailures() + restoreFailed := seclog.MockFailed(true) + defer restoreFailed() + + // Re-enable should reset the failure state. + err = seclog.Enable() + c.Assert(err, IsNil) + c.Check(seclog.GetFailed(), Equals, false) + c.Check(seclog.GetWriteFailures(), Equals, 0) +} + +func (s *SecLogSuite) TestEnableLogsToStandardLogger(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + logBuf, restoreStdLogger := logger.MockLogger() + defer restoreStdLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + c.Check(logBuf.String(), testutil.Contains, "security logger enabled") +} + +func (s *SecLogSuite) TestDisableLogsToStandardLogger(c *C) { + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return s.buf, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + + logBuf, restoreStdLogger := logger.MockLogger() + defer restoreStdLogger() + + err = seclog.Disable() + c.Assert(err, IsNil) + + c.Check(logBuf.String(), testutil.Contains, "security logger disabled") +} + +func (s *SecLogSuite) TestFailureTrackingWriterPassesSetLevel(c *C) { + // Use a levelBuf (defined in slog_test.go) which implements + // levelWriter so we can verify SetLevel is called through + // the failureTrackingWriter wrapper. + lb := &levelBuf{} + restore := seclog.MockNewSink(func(appID string) (io.Writer, error) { + return lb, nil + }) + defer restore() + restoreLogger := seclog.MockGlobalLogger(seclog.NewNopLogger()) + defer restoreLogger() + + err := seclog.Setup(seclog.ImplSlog, seclog.SinkAudit, s.appID, seclog.LevelInfo) + c.Assert(err, IsNil) + lb.Reset() + lb.levels = nil + + seclog.LogLoginSuccess(seclog.SnapdUser{ID: 1, SystemUserName: "test"}) + + // The levelHandler should have called SetLevel on the underlying + // levelBuf through the failureTrackingWriter wrapper. + c.Assert(len(lb.levels), Equals, 1) + c.Check(lb.levels[0], Equals, seclog.LevelInfo) +} diff --git a/seclog/slog.go b/seclog/slog.go new file mode 100644 index 00000000000..16458bd29c6 --- /dev/null +++ b/seclog/slog.go @@ -0,0 +1,234 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- + +// go1.21 is required for log/slog which was added in Go 1.21. +// See https://go.dev/doc/go1.21#slog +// The noslog tag allows excluding the slog-based logger entirely. +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog + +import ( + "context" + "fmt" + "io" + "log/slog" + "sync" + "time" + + "github.com/snapcore/snapd/osutil" +) + +// slogImplementation implements [implFactory]. +type slogImplementation struct{} + +// Ensure [slogImplementation] implements [implFactory]. +var _ implFactory = slogImplementation{} + +// New constructs an slog based [securityLogger] that emits structured JSON to the +// provided [io.Writer]. The returned logger enables dynamic level control via +// an internal [slog.LevelVar]. +func (slogImplementation) New(writer io.Writer, appID string, minLevel Level) securityLogger { + return newSlogLogger(writer, appID, minLevel) +} + +func init() { + registerImpl(ImplSlog, slogImplementation{}) +} + +func newSlogLogger(writer io.Writer, appID string, minLevel Level) securityLogger { + levelVar := new(slog.LevelVar) + levelVar.Set(slog.Level(minLevel)) + var handler slog.Handler = newJsonHandler(writer, levelVar) + if lw, ok := writer.(levelWriter); ok { + handler = newLevelHandler(handler, lw) + } + + logger := &slogLogger{ + // enable dynamic level adjustment + levelVar: levelVar, + // always include app_id and type + logger: slog.New(handler).With( + slog.String("app_id", appID), + slog.String("type", "security"), + ), + } + return logger +} + +// slogLogger implements [securityLogger] and is constructed by the +// [slogImplementation]. It wraps a [slog.Logger] and provides the required +// methods. The logger emits structured JSON with a predefined schema for +// built-in attributes and supports dynamic log level control via an internal +// [slog.LevelVar]. When used with a [levelWriter] sink, it ensures that +// each message is written with the correct severity level. +type slogLogger struct { + logger *slog.Logger + levelVar *slog.LevelVar +} + +// Ensure [slogLogger] implements [securityLogger]. +var _ securityLogger = (*slogLogger)(nil) + +// SlogLogger is a test only helper to retrieve a pointer to the underlying +// [slog.Logger]. +func (l *slogLogger) SlogLogger() *slog.Logger { + osutil.MustBeTestBinary("SlogLogger() is for testing only") + return l.logger +} + +// LogLoggingEnabled implements [securityLogger.LogLoggingEnabled]. +func (l *slogLogger) LogLoggingEnabled() { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelInfo), + "Security logging enabled", + slog.Attr{Key: "category", Value: slog.StringValue("SYS")}, + slog.Attr{Key: "event", Value: slog.StringValue("sys_logging_enabled")}, + ) +} + +// LogLoggingDisabled implements [securityLogger.LogLoggingDisabled]. +func (l *slogLogger) LogLoggingDisabled() { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelCritical), + "Security logging disabled", + slog.Attr{Key: "category", Value: slog.StringValue("SYS")}, + slog.Attr{Key: "event", Value: slog.StringValue("sys_logging_disabled")}, + ) +} + +// LogLoginSuccess implements [securityLogger.LogLoginSuccess]. +func (l *slogLogger) LogLoginSuccess(user SnapdUser) { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelInfo), + fmt.Sprintf("User %s login success", user.String()), + slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, + slog.Attr{Key: "event", Value: slog.StringValue("authn_login_success")}, + slog.Any("user", user), + ) +} + +// LogLoginFailure implements [securityLogger.LogLoginFailure]. +func (l *slogLogger) LogLoginFailure(user SnapdUser, reason Reason) { + l.logger.LogAttrs( + context.Background(), + slog.Level(LevelWarn), + fmt.Sprintf("User %s login failure: %s", user.String(), reason.String()), + slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, + slog.Attr{Key: "event", Value: slog.StringValue("authn_login_failure")}, + slog.Any("user", user), + slog.Any("error", reason), + ) +} + +// LogValue implements [slog.LogValuer], allowing SnapdUser to be +// used directly as a structured log attribute value. +func (u SnapdUser) LogValue() slog.Value { + expiration := "never" + if !u.Expiration.IsZero() { + expiration = u.Expiration.UTC().Format(time.RFC3339Nano) + } + return slog.GroupValue( + slog.Int64("snapd-user-id", u.ID), + slog.String("system-user-name", u.SystemUserName), + slog.String("store-user-email", u.StoreUserEmail), + slog.String("expiration", expiration), + ) +} + +// newJsonHandler returns a slog JSON handler configured for security logs. +// +// It writes newline-delimited JSON to writer and enforces a schema for the +// built-in attributes: +// - time: key "datetime", formatted in UTC using [time.RFC3339Nano] +// - level: rendered as a string via [Level.String] +// - message: key "description" +// - app_id: always included with the value provided to newSlogLogger +// - type: always included with the value "security" +// +// Additional attributes are preserved verbatim, including nested groups. The +// handler logs at or above the minLevel threshold. It does not +// close or sync writer. +func newJsonHandler(writer io.Writer, minLevel slog.Leveler) slog.Handler { + options := &slog.HandlerOptions{ + Level: minLevel, + ReplaceAttr: func(groups []string, attr slog.Attr) slog.Attr { + switch attr.Key { + case slog.TimeKey: + // use "datetime" instead of default "time" + attr.Key = "datetime" + if t, ok := attr.Value.Any().(time.Time); ok { + // convert to formatted string + attr.Value = slog.StringValue(t.UTC().Format(time.RFC3339Nano)) + } + case slog.LevelKey: + if l, ok := attr.Value.Any().(slog.Level); ok { + attr.Value = slog.StringValue(Level(l).String()) + } + case slog.MessageKey: + // use "description" instead of default "msg" + attr.Key = "description" + } + return attr + }, + } + + return slog.NewJSONHandler(writer, options) +} + +// levelHandler is a [slog.Handler] wrapper that sets the level on a +// [levelWriter] before each message is handled. This ensures that the +// written output carries the correct per-message priority. +// +// All derived handlers returned by WithAttrs and WithGroup share the same +// [levelWriter] and mutex, since they write to the same sink. +type levelHandler struct { + inner slog.Handler + lw levelWriter + mu *sync.Mutex +} + +func newLevelHandler(inner slog.Handler, lw levelWriter) slog.Handler { + return &levelHandler{inner: inner, lw: lw, mu: &sync.Mutex{}} +} + +func (h *levelHandler) Enabled(ctx context.Context, level slog.Level) bool { + return h.inner.Enabled(ctx, level) +} + +func (h *levelHandler) Handle(ctx context.Context, r slog.Record) error { + h.mu.Lock() + defer h.mu.Unlock() + + h.lw.SetLevel(Level(r.Level)) + return h.inner.Handle(ctx, r) +} + +func (h *levelHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + return &levelHandler{inner: h.inner.WithAttrs(attrs), lw: h.lw, mu: h.mu} +} + +// WithGroup is required by the [slog.Handler] interface but is not +// currently used by seclog. +func (h *levelHandler) WithGroup(name string) slog.Handler { + return &levelHandler{inner: h.inner.WithGroup(name), lw: h.lw, mu: h.mu} +} diff --git a/seclog/slog_test.go b/seclog/slog_test.go new file mode 100644 index 00000000000..02711232db8 --- /dev/null +++ b/seclog/slog_test.go @@ -0,0 +1,332 @@ +// -*- Mode: Go; indent-tabs-mode: t -*- +//go:build go1.21 && !noslog + +/* + * Copyright (C) 2026 Canonical Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 3 as + * published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package seclog_test + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "time" + + "log/slog" + + . "gopkg.in/check.v1" + + "github.com/snapcore/snapd/seclog" + "github.com/snapcore/snapd/testutil" +) + +type SlogSuite struct { + testutil.BaseTest + buf *bytes.Buffer + appID string + factory seclog.ImplFactory +} + +var _ = Suite(&SlogSuite{}) + +func (s *SlogSuite) SetUpSuite(c *C) { + s.buf = &bytes.Buffer{} + s.appID = "canonical.snapd" + s.factory = seclog.SlogImplementation{} +} + +func (s *SlogSuite) SetUpTest(c *C) { + s.BaseTest.SetUpTest(c) + s.buf.Reset() +} + +func (s *SlogSuite) TearDownTest(c *C) { + s.BaseTest.TearDownTest(c) +} + +// extractSlogLogger is a test helper to extract the internal [slog.Logger] from +// SecurityLogger. +func extractSlogLogger(logger seclog.SecurityLogger) (*slog.Logger, error) { + if l, ok := logger.(*seclog.SlogLogger); !ok { + return nil, errors.New("cannot extract slog logger") + } else { + // return the internal slog logger + return l.SlogLogger(), nil + } +} + +func (s *SlogSuite) TestSlogImplementation(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) + c.Check(logger, NotNil) +} + +// baseAttrs represents the non-optional attributes that is present in +// every record +type baseAttrs struct { + Datetime time.Time `json:"datetime"` + Level string `json:"level"` + Description string `json:"description"` + AppID string `json:"app_id"` + Type string `json:"type"` + Category string `json:"category"` +} + +// orderedKeys extracts the top-level JSON object keys in order. +func orderedKeys(data []byte) ([]string, error) { + decoder := json.NewDecoder(bytes.NewReader(data)) + // consume opening '{' + token, err := decoder.Token() + if err != nil { + return nil, err + } + if delim, ok := token.(json.Delim); !ok || delim != '{' { + return nil, errors.New("expected '{' delimiter") + } + var keys []string + for decoder.More() { + token, err = decoder.Token() + if err != nil { + return nil, err + } + key, ok := token.(string) + if !ok { + return nil, errors.New("expected string key") + } + keys = append(keys, key) + // skip value + var raw json.RawMessage + if err := decoder.Decode(&raw); err != nil { + return nil, err + } + } + return keys, nil +} + +type attrsAllTypes struct { + baseAttrs + String string `json:"string"` + Duration time.Duration `json:"duration"` + Timestamp time.Time `json:"timestamp"` + Float64 float64 `json:"float64"` + Int64 int64 `json:"int64"` + Int int64 `json:"int"` + Uint64 uint64 `json:"uint64"` + Any any `json:"any"` +} + +func (s *SlogSuite) TestHandlerAttrsAllTypes(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + sl, err := extractSlogLogger(logger) + c.Assert(err, IsNil) + sl.LogAttrs( + context.Background(), + slog.Level(seclog.LevelInfo), + "test description", + slog.Attr{Key: "category", Value: slog.StringValue("AUTHN")}, + slog.Attr{Key: "string", Value: slog.StringValue("test string")}, + slog.Attr{Key: "duration", Value: slog.DurationValue(time.Duration(90 * time.Second))}, + slog.Attr{ + Key: "timestamp", + Value: slog.TimeValue(time.Date(2025, 10, 8, 8, 0, 0, 0, time.UTC)), + }, + slog.Attr{Key: "float64", Value: slog.Float64Value(3.141592653589793)}, + slog.Attr{Key: "int64", Value: slog.Int64Value(-4611686018427387904)}, + slog.Attr{Key: "int", Value: slog.IntValue(-4294967296)}, + slog.Attr{Key: "uint64", Value: slog.Uint64Value(4294967295)}, + // AnyValue returns value of KindInt64, the original + // numeric type is not preserved + slog.Attr{Key: "any", Value: slog.AnyValue(map[string]any{"k": "v", "n": int(1)})}, + ) + + var obtained attrsAllTypes + err = json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + + c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) + c.Check(obtained.Level, Equals, "INFO") + c.Check(obtained.Description, Equals, "test description") + c.Check(obtained.AppID, Equals, s.appID) + c.Check(obtained.Type, Equals, "security") + c.Check(obtained.Category, Equals, "AUTHN") + + c.Check(obtained.String, Equals, "test string") + c.Check(obtained.Duration, Equals, time.Duration(90*time.Second)) + c.Check(obtained.Timestamp, Equals, time.Date(2025, 10, 8, 8, 0, 0, 0, time.UTC)) + c.Check(obtained.Float64, Equals, float64(3.141592653589793)) + c.Check(obtained.Int64, Equals, int64(-4611686018427387904)) + c.Check(obtained.Int, Equals, int64(-4294967296)) + c.Check(obtained.Uint64, Equals, uint64(4294967295)) + c.Check(obtained.Any, DeepEquals, map[string]any{"k": "v", "n": float64(1)}) +} + +func (s *SlogSuite) TestLogLoginSuccess(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + type LoginSuccess struct { + baseAttrs + Event string `json:"event"` + User struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration string `json:"expiration"` + } `json:"user"` + } + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@gmail.com", + SystemUserName: "jdoe", + } + logger.LogLoginSuccess(user) + + var obtained LoginSuccess + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) + c.Check(obtained.Level, Equals, "INFO") + c.Check(obtained.Description, Equals, "User 42:user@gmail.com:jdoe login success") + c.Check(obtained.AppID, Equals, s.appID) + c.Check(obtained.Event, Equals, "authn_login_success") + c.Check(obtained.User.ID, Equals, int64(42)) + c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") + c.Check(obtained.User.SystemUserName, Equals, "jdoe") + c.Check(obtained.User.Expiration, Equals, "never") + + // verify key order for human readability + keys, err := orderedKeys(s.buf.Bytes()) + c.Assert(err, IsNil) + c.Check(keys, DeepEquals, []string{ + "datetime", "level", "description", + "app_id", "type", "category", "event", "user", + }) +} + +func (s *SlogSuite) TestLogLoginSuccessWithExpiration(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + type LoginSuccess struct { + baseAttrs + Event string `json:"event"` + User struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration string `json:"expiration"` + } `json:"user"` + } + + expiry := time.Date(2026, 6, 15, 12, 0, 0, 0, time.UTC) + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@gmail.com", + SystemUserName: "jdoe", + Expiration: expiry, + } + logger.LogLoginSuccess(user) + + var obtained LoginSuccess + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(obtained.User.Expiration, Equals, "2026-06-15T12:00:00Z") +} + +func (s *SlogSuite) TestLogLoginFailure(c *C) { + logger := s.factory.New(s.buf, s.appID, seclog.LevelInfo) + c.Assert(logger, NotNil) + + type loginFailure struct { + baseAttrs + Event string `json:"event"` + User struct { + ID int64 `json:"snapd-user-id"` + SystemUserName string `json:"system-user-name"` + StoreUserEmail string `json:"store-user-email"` + Expiration string `json:"expiration"` + } `json:"user"` + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + + user := seclog.SnapdUser{ + ID: 42, + StoreUserEmail: "user@gmail.com", + SystemUserName: "jdoe", + } + logger.LogLoginFailure(user, seclog.Reason{Code: seclog.ReasonInvalidCredentials, Message: "invalid credentials"}) + + var obtained loginFailure + err := json.Unmarshal(s.buf.Bytes(), &obtained) + c.Assert(err, IsNil) + c.Check(time.Since(obtained.Datetime) < time.Second, Equals, true) + c.Check(obtained.Level, Equals, "WARN") + c.Check(obtained.Description, Equals, "User 42:user@gmail.com:jdoe login failure: invalid-credentials:invalid credentials") + c.Check(obtained.AppID, Equals, s.appID) + c.Check(obtained.Event, Equals, "authn_login_failure") + c.Check(obtained.User.ID, Equals, int64(42)) + c.Check(obtained.User.StoreUserEmail, Equals, "user@gmail.com") + c.Check(obtained.User.SystemUserName, Equals, "jdoe") + c.Check(obtained.User.Expiration, Equals, "never") + c.Check(obtained.Error.Code, Equals, seclog.ReasonInvalidCredentials) + c.Check(obtained.Error.Message, Equals, "invalid credentials") + + // verify key order for human readability + keys, err := orderedKeys(s.buf.Bytes()) + c.Assert(err, IsNil) + c.Check(keys, DeepEquals, []string{ + "datetime", "level", "description", + "app_id", "type", "category", "event", "user", "error", + }) +} + +// levelBuf is a bytes.Buffer that also implements [seclog.LevelWriter], +// recording the level set before each log message is written. +type levelBuf struct { + bytes.Buffer + levels []seclog.Level +} + +func (lb *levelBuf) SetLevel(l seclog.Level) { + lb.levels = append(lb.levels, l) +} + +// Ensure levelBuf satisfies the interface. +var _ seclog.LevelWriter = (*levelBuf)(nil) + +func (s *SlogSuite) TestLevelHandlerSetsLevelBeforeWrite(c *C) { + lb := &levelBuf{} + logger := seclog.SlogImplementation{}.New(lb, s.appID, seclog.LevelInfo) + + slogLogger, err := extractSlogLogger(logger) + c.Assert(err, IsNil) + + // Use seclog level values cast to slog.Level so they pass the + // level threshold set by newSlogLogger (slog.Level(seclog.LevelInfo)). + slogLogger.Log(context.Background(), slog.Level(seclog.LevelInfo), "info message") + slogLogger.Log(context.Background(), slog.Level(seclog.LevelWarn), "warn message") + + c.Assert(len(lb.levels), Equals, 2) + c.Check(lb.levels[0], Equals, seclog.LevelInfo) + c.Check(lb.levels[1], Equals, seclog.LevelWarn) +} diff --git a/snap/naming/validate.go b/snap/naming/validate.go index 7ce54e27f99..ab3ebed0842 100644 --- a/snap/naming/validate.go +++ b/snap/naming/validate.go @@ -315,6 +315,15 @@ func validateAssumedSnapdVersion(assumedVersion, currentVersion string) (bool, e var archIsISASupportedByCPU = arch.IsISASupportedByCPU +type IsaError struct { + Flag string + Err error +} + +func (e *IsaError) Error() string { + return fmt.Sprintf("%s: %s", e.Flag, e.Err) +} + // validateAssumedISAArch checks that, when a snap requires an ISA to be supported: // 1. compares the specified with the device's one. If they differ, it exits // without error signaling that the flag is valid @@ -338,7 +347,7 @@ func validateAssumedISAArch(flag string, currentArchitecture string) error { } if err := archIsISASupportedByCPU(tokens[2]); err != nil { - return fmt.Errorf("%s: %s", flag, err) + return &IsaError{Flag: flag, Err: err} } return nil diff --git a/tests/core/enable-disable-units-gpio/task.yaml b/tests/core/enable-disable-units-gpio/task.yaml index 22e4a0ef6c9..a474b4e3497 100644 --- a/tests/core/enable-disable-units-gpio/task.yaml +++ b/tests/core/enable-disable-units-gpio/task.yaml @@ -21,7 +21,7 @@ skip: prepare: | echo "Create/enable fake gpio" - tests.systemd create-and-start-unit fake-gpio "$TESTSLIB/fakegpio/fake-gpio.py" "[Unit]\\nBefore=snap.snapd.interface.gpio-100.service\\n[Service]\\nType=notify" + tests.systemd create-and-start-unit fake-gpio "$TESTSLIB/fakegpio/fake-gpio.py" "[Unit]\\nBefore=snap.snapd.interface.gpio-100.service\\n[Service]\\nType=notify\\nEnvironment=PATH=$PATH" echo "Given a snap declaring a plug on gpio is installed" "$TESTSTOOLS"/snaps-state install-local gpio-consumer diff --git a/tests/lib/nested.sh b/tests/lib/nested.sh index 55cad68a63f..c99e70ddd30 100755 --- a/tests/lib/nested.sh +++ b/tests/lib/nested.sh @@ -1863,7 +1863,7 @@ nested_prepare_tools() { if ! remote.exec "grep -qE PATH=.*$TOOLS_PATH /etc/environment"; then # shellcheck disable=SC2016 REMOTE_PATH="$(remote.exec 'echo $PATH')" - remote.exec "echo PATH=$TOOLS_PATH:$REMOTE_PATH | sudo tee -a /etc/environment" + remote.exec "echo PATH=$TOOLS_PATH:$REMOTE_PATH:/usr/lib/python | sudo tee -a /etc/environment" fi if [ -n "$TAG_FEATURES" ]; then diff --git a/tests/lib/prepare.sh b/tests/lib/prepare.sh index eaa03c754f0..3441b825ce2 100755 --- a/tests/lib/prepare.sh +++ b/tests/lib/prepare.sh @@ -1405,6 +1405,8 @@ setup_reflash_magic() { snap tasks --last=seed || true journalctl -u snapd snap model --verbose + #shellcheck source=tests/lib/nested.sh + . "$TESTSLIB/nested.sh" # remove the above debug lines once the mentioned bug is fixed snap install "--channel=$(nested_get_base_channel)" "$core_name" # TODO set up a trap to clean this up properly? diff --git a/tests/main/disk-space-awareness/task.yaml b/tests/main/disk-space-awareness/task.yaml index a11cd739af8..304a7c63f1e 100644 --- a/tests/main/disk-space-awareness/task.yaml +++ b/tests/main/disk-space-awareness/task.yaml @@ -11,18 +11,26 @@ environment: TMPFSMOUNT: /var/lib/snapd # filling tmpfs mounted under /var/lib/snapd triggers OOM SNAPD_NO_MEMORY_LIMIT: 1 - SUFFICIENT_SIZE: 200M + SUFFICIENT_SIZE: 300M prepare: | systemctl stop snapd.{socket,service} + SNAP_MOUNT_DIR="$(os.paths snap-mount-dir)" + # purge removes the snap mount directory, which needs to be restored + snapd.tool exec snap-mgmt --purge + mkdir -p "$SNAP_MOUNT_DIR" # mount /var/lib/snapd on a tmpfs mount -t tmpfs tmpfs -o size="$SUFFICIENT_SIZE",mode=0755 "$TMPFSMOUNT" systemctl start snapd.{socket,service} + snap wait system seed.loaded restore: | systemctl stop snapd.{socket,service} + SNAP_MOUNT_DIR="$(os.paths snap-mount-dir)" + snapd.tool exec snap-mgmt --purge + mkdir -p "$SNAP_MOUNT_DIR" umount -l "$TMPFSMOUNT" systemctl start snapd.{socket,service} diff --git a/tests/main/security-logging/task.yaml b/tests/main/security-logging/task.yaml new file mode 100644 index 00000000000..2848b09cd32 --- /dev/null +++ b/tests/main/security-logging/task.yaml @@ -0,0 +1,63 @@ +summary: Checks that security audit events are written to the kernel audit log + +details: | + The snapd daemon writes structured security audit events via the kernel + audit subsystem (AUDIT_TRUSTED_APP, type 1121). This test verifies that + a failed login attempt produces an "authn_login_failure" event and, + when store credentials are available, that a successful login produces + an "authn_login_success" event in the audit log. + +# ubuntu-core: auditd is not available as a distro package +systems: [-ubuntu-core-*] + +prepare: | + # Ensure auditd (which provides ausearch) is installed and running. + if ! command -v ausearch; then + #shellcheck source=tests/lib/pkgdb.sh + . "$TESTSLIB/pkgdb.sh" + distro_install_package auditd + systemctl enable --now auditd.service + fi + + # Create an audit checkpoint so we only see events from this test. + ausearch --checkpoint stamp -m 1121 || true + +restore: | + snap logout || true + rm -f stamp + +execute: | + echo "Checking that the security logger initialised" + "$TESTSTOOLS"/journal-state match-log "security logger enabled" + + echo "Checking that a failed login attempt produces an audit event" + echo '{"email":"someemail@testing.com","password":"wrong-password"}' | \ + snap debug api -X POST -H 'Content-Type: application/json' /v2/login || true + + # The audit log entry is the raw JSON payload sent by snapd. + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'authn_login_failure' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'invalid-credentials' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'someemail@testing.com' + + if [ -n "$SPREAD_STORE_USER" ] && [ -n "$SPREAD_STORE_PASSWORD" ]; then + echo "Checking that a successful login produces an audit event" + # Reset the checkpoint so we only see the success event. + ausearch --checkpoint stamp -m 1121 || true + + expect -d -f "$TESTSLIB"/successful_login.exp + + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'authn_login_success' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH "$SPREAD_STORE_USER" + + snap logout + fi + + echo "Checking that restart produces disable and enable audit events" + ausearch --checkpoint stamp -m 1121 || true + systemctl restart snapd.service + snap wait system seed.loaded + + "$TESTSTOOLS"/journal-state match-log "security logger disabled" + "$TESTSTOOLS"/journal-state match-log "security logger enabled" + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'sys_logging_disabled' + ausearch --start checkpoint --checkpoint stamp -m 1121 --raw 2>&1 | MATCH 'sys_logging_enabled' diff --git a/tests/main/snapctl-is-ready/task.yaml b/tests/main/snapctl-is-ready/task.yaml new file mode 100644 index 00000000000..ef8f81f326a --- /dev/null +++ b/tests/main/snapctl-is-ready/task.yaml @@ -0,0 +1,48 @@ +summary: Ensure that snapctl is-ready command works. + +details: | + Verifies that the snapctl is-ready command correctly reports the status of a + change initiated by a snap via snapctl. A test snap with a component is + installed locally. The component is then removed via snapctl (from within the + snap's app context using snap run --shell), which creates a snapctl-remove + change marked with the initiated-by-snap change key for the calling snap. + The implementation also tracks last-access information in the in-memory state + cache rather than as a change attribute. is-ready is then called against that + change ID to verify it reports Done and exits successfully. + + Also verifies that is-ready fails appropriately for invalid change IDs, wrong + argument counts, and changes not initiated by the calling snap. + +systems: [ubuntu-16.04-64, ubuntu-18.04-64, ubuntu-2*, ubuntu-core-*, fedora-*] + +prepare: | + snap pack test-snap/ + snap pack test-comp/ + snap install --dangerous test-snapctl-is-ready_1.0_all.snap + snap install --dangerous test-snapctl-is-ready+comp_1.0.comp + +execute: | + echo "Remove component via snapctl to create a snapctl-remove change" + snap run test-snapctl-is-ready.app snapctl remove +comp + + CHANGE_ID=$(snap debug api /v2/changes?select=all | \ + gojq --raw-output '[.result[] | select(.kind == "snapctl-remove")] | last | .id') + + test -n "$CHANGE_ID" + test "$CHANGE_ID" != "null" + + echo "snapctl is-ready exits 0 for a completed change (stdout is empty; exit code conveys status)" + snap run test-snapctl-is-ready.app snapctl is-ready "$CHANGE_ID" + + echo "snapctl is-ready fails with exit 3 for an invalid change ID" + snap run test-snapctl-is-ready.app snapctl is-ready nonexistent-id || test $? -eq 3 + + echo "snapctl is-ready fails with too few arguments" + not snap run test-snapctl-is-ready.app snapctl is-ready + + echo "snapctl is-ready fails with too many arguments" + not snap run test-snapctl-is-ready.app snapctl is-ready "$CHANGE_ID" extra-arg + + echo "snapctl is-ready fails for a change not initiated by the snap" + INSTALL_CHANGE_ID=$(snap install --no-wait test-snapd-tools) + snap run test-snapctl-is-ready.app snapctl is-ready "$INSTALL_CHANGE_ID" || test $? -eq 3 diff --git a/tests/main/snapctl-is-ready/test-comp/meta/component.yaml b/tests/main/snapctl-is-ready/test-comp/meta/component.yaml new file mode 100644 index 00000000000..64b8573c6ab --- /dev/null +++ b/tests/main/snapctl-is-ready/test-comp/meta/component.yaml @@ -0,0 +1,5 @@ +component: test-snapctl-is-ready+comp +type: standard +version: 1.0 +summary: Test component for snapctl is-ready +description: Test component for snapctl is-ready diff --git a/tests/main/snapctl-is-ready/test-snap/bin/app b/tests/main/snapctl-is-ready/test-snap/bin/app new file mode 100755 index 00000000000..311cb8cb40c --- /dev/null +++ b/tests/main/snapctl-is-ready/test-snap/bin/app @@ -0,0 +1,2 @@ +#!/bin/sh +exec "$@" diff --git a/tests/main/snapctl-is-ready/test-snap/meta/snap.yaml b/tests/main/snapctl-is-ready/test-snap/meta/snap.yaml new file mode 100644 index 00000000000..121ab5f0d5d --- /dev/null +++ b/tests/main/snapctl-is-ready/test-snap/meta/snap.yaml @@ -0,0 +1,12 @@ +name: test-snapctl-is-ready +version: 1.0 +summary: Test snap for snapctl is-ready +apps: + app: + command: bin/app +base: core24 +components: + comp: + summary: test component for snapctl is-ready + description: test component for snapctl is-ready + type: standard diff --git a/tests/main/upgrade-from-release/task.yaml b/tests/main/upgrade-from-release/task.yaml index 6df7040c58f..b20c72b3ed3 100644 --- a/tests/main/upgrade-from-release/task.yaml +++ b/tests/main/upgrade-from-release/task.yaml @@ -44,7 +44,7 @@ execute: | # TODO: add automatic package lookup - manual list maintenance is impractical declare -A EXPECTED_SNAPD_VERSIONS=( ["26.04"]='2.74.1\+ubuntu26.04' - ["25.10"]='2.73\+ubuntu25.10' + ["25.10"]='2.74.1\+ubuntu25.10' ["24.04"]='2.62\+24.04' ["22.04"]='2.55.3\+22.04' ["20.04"]='2.44.3\+20.04' diff --git a/tests/nested/classic/azure-cvm/task.yaml b/tests/nested/classic/azure-cvm/task.yaml index 69c247262bb..405059ccd29 100644 --- a/tests/nested/classic/azure-cvm/task.yaml +++ b/tests/nested/classic/azure-cvm/task.yaml @@ -9,8 +9,6 @@ systems: - -ubuntu-16.04-* - -ubuntu-18.04-* - -ubuntu-20.04-* - # FIXME - - -ubuntu-26.04-* environment: SNAPD_DEB_FROM_REPO: false diff --git a/tests/nested/core/interfaces-custom-devices/task.yaml b/tests/nested/core/interfaces-custom-devices/task.yaml index 9ec3a22507a..ed0a918a5f1 100644 --- a/tests/nested/core/interfaces-custom-devices/task.yaml +++ b/tests/nested/core/interfaces-custom-devices/task.yaml @@ -5,10 +5,10 @@ details: | granting access to the devices it defines. systems: - # FIXME: make it work on 26 - ubuntu-20* - ubuntu-22* - ubuntu-24* + - ubuntu-26* prepare: | # Add our interface to the gadget snap diff --git a/tests/nested/manual/component-recovery-system-offline/task.yaml b/tests/nested/manual/component-recovery-system-offline/task.yaml index 03dc50501b4..0a0daeb635e 100644 --- a/tests/nested/manual/component-recovery-system-offline/task.yaml +++ b/tests/nested/manual/component-recovery-system-offline/task.yaml @@ -9,8 +9,8 @@ details: | HTTP form. systems: - # FIXME: make it work on 26 - ubuntu-24* + - ubuntu-26* environment: MODEL_JSON: $TESTSLIB/assertions/test-snapd-component-recovery-system-pc-VERSION.json diff --git a/tests/nested/manual/component-recovery-system/task.yaml b/tests/nested/manual/component-recovery-system/task.yaml index 4f4dd0ebcd9..762673728e5 100644 --- a/tests/nested/manual/component-recovery-system/task.yaml +++ b/tests/nested/manual/component-recovery-system/task.yaml @@ -5,8 +5,8 @@ details: | validates that the newly created system can be rebooted into. systems: - # FIXME: make it work on 26 - ubuntu-24* + - ubuntu-26* environment: MODEL_JSON: $TESTSLIB/assertions/test-snapd-component-recovery-system-pc-VERSION.json diff --git a/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml b/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml index 60c2ef205cd..31a0306ffd8 100644 --- a/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml +++ b/tests/nested/manual/core20-fault-inject-on-install-component/task.yaml @@ -8,8 +8,6 @@ systems: - -ubuntu-1* - -ubuntu-20* - -ubuntu-22* - # FIXME - - -ubuntu-26* environment: TAG/kernel_panic_prepare_kernel_components: prepare-kernel-components diff --git a/tests/nested/manual/seeding-failure/task.yaml b/tests/nested/manual/seeding-failure/task.yaml index e9a78e63313..f3f8d28625c 100644 --- a/tests/nested/manual/seeding-failure/task.yaml +++ b/tests/nested/manual/seeding-failure/task.yaml @@ -14,8 +14,6 @@ systems: - -ubuntu-1* - -ubuntu-20* - -ubuntu-22* - # FIXME - - -ubuntu-26* environment: MODEL_JSON: $TESTSLIB/assertions/test-snapd-failed-seeding-pc-VERSION.json diff --git a/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml b/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml index 3829166a64d..04031878b3c 100644 --- a/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml +++ b/tests/nested/manual/snapd-removes-vulnerable-snap-confine-revs/task.yaml @@ -7,10 +7,10 @@ details: | # just focal is fine for this test - we only need to check that things happen on # classic systems: - # FIXME: make it work on 26 - ubuntu-20* - ubuntu-22* - ubuntu-24* + - ubuntu-26* environment: # which snap snapd comes from in this test diff --git a/tests/regression/lp-1813365/task.yaml b/tests/regression/lp-1813365/task.yaml index 9c814f7ce0c..c94e78b123f 100644 --- a/tests/regression/lp-1813365/task.yaml +++ b/tests/regression/lp-1813365/task.yaml @@ -30,5 +30,5 @@ restore: | rm -f /tmp/logger.log execute: | - su -l -c "$(pwd)/helper" test + su -l -c "PATH=\$PATH:/usr/lib/python $(pwd)/helper" test not test -e /tmp/logger.log diff --git a/tests/regression/lp-1871652/task.yaml b/tests/regression/lp-1871652/task.yaml index 32b516b4e4c..f696c293218 100644 --- a/tests/regression/lp-1871652/task.yaml +++ b/tests/regression/lp-1871652/task.yaml @@ -15,11 +15,11 @@ details: | aware of the shutdown. # Run on a system matching the guest container. -systems: [ubuntu-18.04-64] +systems: [ubuntu-24.04-64] prepare: | "$TESTSTOOLS"/lxd-state prepare-snap - "$TESTSTOOLS"/lxd-state launch --remote ubuntu --image 18.04 --name bionic + "$TESTSTOOLS"/lxd-state launch --remote ubuntu --image 24.04 --name bionic # Install snapd inside the container and then install the core snap so that # we get re-execution logic to applies as snapd in the store is more recent