Skip to content

Commit ad3d12d

Browse files
committed
test: add tests (unit/api) for additional api keys
1 parent fbfd94d commit ad3d12d

27 files changed

+584
-40
lines changed

main.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ func main() {
122122
os.Exit(0)
123123
}
124124
config = conf.Load(*configFlag, version)
125-
slog.Info("loaded configuration", "configFile", *configFlag)
126125

127126
// Configure Swagger docs
128127
docs.SwaggerInfo.BasePath = config.Server.BasePath + "/api"

middlewares/authenticate.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,16 @@ type AuthenticateMiddleware struct {
3434
optionalForMethods []string
3535
redirectTarget string // optional
3636
redirectErrorMessage string // optional
37-
onlyRWApiKey bool
37+
requireFullAccessKey bool // true only for heartbeat routes
3838
}
3939

4040
func NewAuthenticateMiddleware(userService services.IUserService) *AuthenticateMiddleware {
4141
return &AuthenticateMiddleware{
42-
config: conf.Get(),
43-
userSrvc: userService,
44-
optionalForPaths: []string{},
45-
optionalForMethods: []string{},
46-
onlyRWApiKey: false,
42+
config: conf.Get(),
43+
userSrvc: userService,
44+
optionalForPaths: []string{},
45+
optionalForMethods: []string{},
46+
requireFullAccessKey: false,
4747
}
4848
}
4949

@@ -67,8 +67,8 @@ func (m *AuthenticateMiddleware) WithRedirectErrorMessage(message string) *Authe
6767
return m
6868
}
6969

70-
func (m *AuthenticateMiddleware) WithOnlyRWApiKey(onlyRW bool) *AuthenticateMiddleware {
71-
m.onlyRWApiKey = onlyRW
70+
func (m *AuthenticateMiddleware) WithFullAccessOnly(readOnly bool) *AuthenticateMiddleware {
71+
m.requireFullAccessKey = readOnly
7272
return m
7373
}
7474

@@ -145,7 +145,7 @@ func (m *AuthenticateMiddleware) tryGetUserByApiKeyHeader(r *http.Request) (*mod
145145

146146
var user *models.User
147147
userKey := strings.TrimSpace(key)
148-
user, err = m.userSrvc.GetUserByKey(userKey, m.onlyRWApiKey)
148+
user, err = m.userSrvc.GetUserByKey(userKey, m.requireFullAccessKey)
149149
if err != nil {
150150
return nil, err
151151
}
@@ -159,7 +159,7 @@ func (m *AuthenticateMiddleware) tryGetUserByApiKeyQuery(r *http.Request) (*mode
159159
if userKey == "" {
160160
return nil, errEmptyKey
161161
}
162-
user, err := m.userSrvc.GetUserByKey(userKey, m.onlyRWApiKey)
162+
user, err := m.userSrvc.GetUserByKey(userKey, m.requireFullAccessKey)
163163
if err != nil {
164164
return nil, err
165165
}

middlewares/authenticate_test.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ import (
2323

2424
func TestAuthenticateMiddleware_tryGetUserByApiKeyHeader_Success(t *testing.T) {
2525
testApiKey := "z5uig69cn9ut93n"
26-
readOnlyApiKey := false
2726
testToken := base64.StdEncoding.EncodeToString([]byte(testApiKey))
2827
testUser := &models.User{ApiKey: testApiKey}
28+
// In the case of the API Key from User Model - it's always full access
29+
testApiKeyRequireFullAccess := false
2930

3031
mockRequest := &http.Request{
3132
Header: http.Header{
@@ -34,7 +35,7 @@ func TestAuthenticateMiddleware_tryGetUserByApiKeyHeader_Success(t *testing.T) {
3435
}
3536

3637
userServiceMock := new(mocks.UserServiceMock)
37-
userServiceMock.On("GetUserByKey", testApiKey, readOnlyApiKey).Return(testUser, nil)
38+
userServiceMock.On("GetUserByKey", testApiKey, testApiKeyRequireFullAccess).Return(testUser, nil)
3839

3940
sut := NewAuthenticateMiddleware(userServiceMock)
4041

@@ -58,6 +59,29 @@ func TestAuthenticateMiddleware_tryGetUserByApiKeyHeader_Invalid(t *testing.T) {
5859
userServiceMock := new(mocks.UserServiceMock)
5960

6061
sut := NewAuthenticateMiddleware(userServiceMock)
62+
sut.WithFullAccessOnly(false)
63+
64+
result, err := sut.tryGetUserByApiKeyHeader(mockRequest)
65+
66+
assert.Error(t, err)
67+
assert.Nil(t, result)
68+
}
69+
70+
func TestAuthenticateMiddleware_tryGetUserByApiKeyHeaderWithReadOnlyKey_Invalid(t *testing.T) {
71+
testApiKey := "read-only-additional-key"
72+
testToken := base64.StdEncoding.EncodeToString([]byte(testApiKey))
73+
74+
mockRequest := &http.Request{
75+
Header: http.Header{
76+
"Authorization": []string{fmt.Sprintf("Basic %s", testToken)},
77+
},
78+
}
79+
80+
userServiceMock := new(mocks.UserServiceMock)
81+
userServiceMock.On("GetUserByKey", testApiKey, true).Return(nil, errors.New("forbidden: requires full access"))
82+
83+
sut := NewAuthenticateMiddleware(userServiceMock)
84+
sut.WithFullAccessOnly(true)
6185

6286
result, err := sut.tryGetUserByApiKeyHeader(mockRequest)
6387

@@ -67,8 +91,9 @@ func TestAuthenticateMiddleware_tryGetUserByApiKeyHeader_Invalid(t *testing.T) {
6791

6892
func TestAuthenticateMiddleware_tryGetUserByApiKeyQuery_Success(t *testing.T) {
6993
testApiKey := "z5uig69cn9ut93n"
70-
readOnlyApiKey := false
7194
testUser := &models.User{ApiKey: testApiKey}
95+
// In the case of the API Key from User Model - it's always full access
96+
testApiKeyRequireFullAccess := true
7297

7398
params := url.Values{}
7499
params.Add("api_key", testApiKey)
@@ -79,9 +104,10 @@ func TestAuthenticateMiddleware_tryGetUserByApiKeyQuery_Success(t *testing.T) {
79104
}
80105

81106
userServiceMock := new(mocks.UserServiceMock)
82-
userServiceMock.On("GetUserByKey", testApiKey, readOnlyApiKey).Return(testUser, nil)
107+
userServiceMock.On("GetUserByKey", testApiKey, testApiKeyRequireFullAccess).Return(testUser, nil)
83108

84109
sut := NewAuthenticateMiddleware(userServiceMock)
110+
sut.WithFullAccessOnly(true)
85111

86112
result, err := sut.tryGetUserByApiKeyQuery(mockRequest)
87113

mocks/api_key_service.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package mocks
2+
3+
import (
4+
"github.com/muety/wakapi/models"
5+
"github.com/stretchr/testify/mock"
6+
)
7+
8+
type MockApiKeyService struct {
9+
mock.Mock
10+
}
11+
12+
func (m *MockApiKeyService) GetByApiKey(apiKey string, requireFullAccessKey bool) (*models.ApiKey, error) {
13+
args := m.Called(apiKey, requireFullAccessKey)
14+
if args.Get(0) == nil {
15+
return nil, args.Error(1)
16+
}
17+
return args.Get(0).(*models.ApiKey), args.Error(1)
18+
}
19+
20+
func (m *MockApiKeyService) GetByUser(userID string) ([]*models.ApiKey, error) {
21+
args := m.Called(userID)
22+
if args.Get(0) == nil {
23+
return nil, args.Error(1)
24+
}
25+
return args.Get(0).([]*models.ApiKey), args.Error(1)
26+
}
27+
28+
func (m *MockApiKeyService) Create(apiKey *models.ApiKey) (*models.ApiKey, error) {
29+
args := m.Called(apiKey)
30+
if args.Get(0) == nil {
31+
return nil, args.Error(1)
32+
}
33+
return args.Get(0).(*models.ApiKey), args.Error(1)
34+
}
35+
36+
func (m *MockApiKeyService) Delete(apiKey *models.ApiKey) error {
37+
args := m.Called(apiKey)
38+
return args.Error(0)
39+
}

mocks/mail_service.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package mocks
2+
3+
import (
4+
"time"
5+
6+
"github.com/muety/wakapi/models"
7+
"github.com/stretchr/testify/mock"
8+
)
9+
10+
type MailServiceMock struct {
11+
mock.Mock
12+
}
13+
14+
func (m *MailServiceMock) SendPasswordReset(user *models.User, resetLink string) error {
15+
args := m.Called(user, resetLink)
16+
return args.Error(0)
17+
}
18+
19+
func (m *MailServiceMock) SendWakatimeFailureNotification(user *models.User, numFailures int) error {
20+
args := m.Called(user, numFailures)
21+
return args.Error(0)
22+
}
23+
24+
func (m *MailServiceMock) SendImportNotification(user *models.User, duration time.Duration, numHeartbeats int) error {
25+
args := m.Called(user, duration, numHeartbeats)
26+
return args.Error(0)
27+
}
28+
29+
func (m *MailServiceMock) SendReport(user *models.User, report *models.Report) error {
30+
args := m.Called(user, report)
31+
return args.Error(0)
32+
}
33+
34+
func (m *MailServiceMock) SendSubscriptionNotification(user *models.User, hasExpired bool) error {
35+
args := m.Called(user, hasExpired)
36+
return args.Error(0)
37+
}

mocks/user_repository.go

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
package mocks
2+
3+
import (
4+
"time"
5+
6+
"github.com/muety/wakapi/models"
7+
"github.com/stretchr/testify/mock"
8+
"gorm.io/gorm"
9+
)
10+
11+
type UserRepositoryMock struct {
12+
BaseRepositoryMock
13+
mock.Mock
14+
}
15+
16+
func (m *UserRepositoryMock) FindOne(user models.User) (*models.User, error) {
17+
args := m.Called(user)
18+
if args.Get(0) == nil {
19+
return nil, args.Error(1)
20+
}
21+
return args.Get(0).(*models.User), args.Error(1)
22+
}
23+
24+
func (m *UserRepositoryMock) GetByIds(userIds []string) ([]*models.User, error) {
25+
args := m.Called(userIds)
26+
if args.Get(0) == nil {
27+
return nil, args.Error(1)
28+
}
29+
return args.Get(0).([]*models.User), args.Error(1)
30+
}
31+
32+
func (m *UserRepositoryMock) GetAll() ([]*models.User, error) {
33+
args := m.Called()
34+
if args.Get(0) == nil {
35+
return nil, args.Error(1)
36+
}
37+
return args.Get(0).([]*models.User), args.Error(1)
38+
}
39+
40+
func (m *UserRepositoryMock) GetMany(ids []string) ([]*models.User, error) {
41+
args := m.Called(ids)
42+
if args.Get(0) == nil {
43+
return nil, args.Error(1)
44+
}
45+
return args.Get(0).([]*models.User), args.Error(1)
46+
}
47+
48+
func (m *UserRepositoryMock) GetAllByReports(reportsEnabled bool) ([]*models.User, error) {
49+
args := m.Called(reportsEnabled)
50+
if args.Get(0) == nil {
51+
return nil, args.Error(1)
52+
}
53+
return args.Get(0).([]*models.User), args.Error(1)
54+
}
55+
56+
func (m *UserRepositoryMock) GetAllByLeaderboard(leaderboardEnabled bool) ([]*models.User, error) {
57+
args := m.Called(leaderboardEnabled)
58+
if args.Get(0) == nil {
59+
return nil, args.Error(1)
60+
}
61+
return args.Get(0).([]*models.User), args.Error(1)
62+
}
63+
64+
func (m *UserRepositoryMock) GetByLoggedInBefore(t time.Time) ([]*models.User, error) {
65+
args := m.Called(t)
66+
if args.Get(0) == nil {
67+
return nil, args.Error(1)
68+
}
69+
return args.Get(0).([]*models.User), args.Error(1)
70+
}
71+
72+
func (m *UserRepositoryMock) GetByLoggedInAfter(t time.Time) ([]*models.User, error) {
73+
args := m.Called(t)
74+
if args.Get(0) == nil {
75+
return nil, args.Error(1)
76+
}
77+
return args.Get(0).([]*models.User), args.Error(1)
78+
}
79+
80+
func (m *UserRepositoryMock) GetByLastActiveAfter(t time.Time) ([]*models.User, error) {
81+
args := m.Called(t)
82+
if args.Get(0) == nil {
83+
return nil, args.Error(1)
84+
}
85+
return args.Get(0).([]*models.User), args.Error(1)
86+
}
87+
88+
func (m *UserRepositoryMock) Count() (int64, error) {
89+
args := m.Called()
90+
return args.Get(0).(int64), args.Error(1)
91+
}
92+
93+
func (m *UserRepositoryMock) InsertOrGet(user *models.User) (*models.User, bool, error) {
94+
args := m.Called(user)
95+
if args.Get(0) == nil {
96+
return nil, args.Bool(1), args.Error(2)
97+
}
98+
return args.Get(0).(*models.User), args.Bool(1), args.Error(2)
99+
}
100+
101+
func (m *UserRepositoryMock) Update(user *models.User) (*models.User, error) {
102+
args := m.Called(user)
103+
if args.Get(0) == nil {
104+
return nil, args.Error(1)
105+
}
106+
return args.Get(0).(*models.User), args.Error(1)
107+
}
108+
109+
func (m *UserRepositoryMock) UpdateField(user *models.User, key string, value interface{}) (*models.User, error) {
110+
args := m.Called(user, key, value)
111+
if args.Get(0) == nil {
112+
return nil, args.Error(1)
113+
}
114+
return args.Get(0).(*models.User), args.Error(1)
115+
}
116+
117+
func (m *UserRepositoryMock) Delete(user *models.User) error {
118+
args := m.Called(user)
119+
return args.Error(0)
120+
}
121+
122+
func (m *UserRepositoryMock) DeleteTx(user *models.User, tx *gorm.DB) error {
123+
args := m.Called(user, tx)
124+
return args.Error(0)
125+
}

repositories/api_key.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,16 @@ func (r *ApiKeyRepository) GetAll() ([]*models.ApiKey, error) {
2626
return keys, nil
2727
}
2828

29-
func (r *ApiKeyRepository) GetByApiKey(apiKey string, readOnly bool) (*models.ApiKey, error) {
29+
func (r *ApiKeyRepository) GetByApiKey(apiKey string, requireFullAccessKey bool) (*models.ApiKey, error) {
3030
key := &models.ApiKey{}
31-
if err := r.db.Where(&models.ApiKey{ApiKey: apiKey, ReadOnly: readOnly}).First(key).Error; err != nil {
32-
return key, err
31+
32+
query := r.db.Preload("User").Where("api_key = ?", apiKey)
33+
if requireFullAccessKey {
34+
query = query.Where("read_only = ?", false)
35+
}
36+
37+
if err := query.First(key).Error; err != nil {
38+
return nil, err
3339
}
3440
return key, nil
3541
}

routes/api/heartbeat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func NewHeartbeatApiHandler(userService services.IUserService, heartbeatService
3737
func (h *HeartbeatApiHandler) RegisterRoutes(router chi.Router) {
3838
router.Group(func(r chi.Router) {
3939
r.Use(
40-
middlewares.NewAuthenticateMiddleware(h.userSrvc).WithOptionalForMethods(http.MethodOptions).WithOnlyRWApiKey(true).Handler,
40+
middlewares.NewAuthenticateMiddleware(h.userSrvc).WithOptionalForMethods(http.MethodOptions).WithFullAccessOnly(true).Handler,
4141
customMiddleware.NewWakatimeRelayMiddleware().Handler,
4242
)
4343
// see https://github.com/muety/wakapi/issues/203

routes/login.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"github.com/duke-git/lancet/v2/slice"
1414
"github.com/go-chi/chi/v5"
1515
"github.com/go-chi/httprate"
16-
1716
conf "github.com/muety/wakapi/config"
1817
"github.com/muety/wakapi/middlewares"
1918
"github.com/muety/wakapi/models"

0 commit comments

Comments
 (0)