From 06465c85dbb29c8ed6b8809a34ceb44fcafe2beb Mon Sep 17 00:00:00 2001 From: Baitinq Date: Tue, 28 May 2024 23:39:41 +0200 Subject: rest-api: handle payloads with an api_key --- src/rest-api/handler/db.go | 29 +++++++++++++++++++++++++---- src/rest-api/handler/handler.go | 30 +++++++++++++++++++++++++----- src/rest-api/handler/handler_test.go | 6 +++--- src/rest-api/handler/mock_db.go | 23 +++++++++++++++++++---- 4 files changed, 72 insertions(+), 16 deletions(-) (limited to 'src/rest-api/handler') diff --git a/src/rest-api/handler/db.go b/src/rest-api/handler/db.go index 0093c41..49d8e35 100644 --- a/src/rest-api/handler/db.go +++ b/src/rest-api/handler/db.go @@ -9,7 +9,8 @@ import ( //go:generate mockgen -source=$GOFILE -package=$GOPACKAGE -destination=mock_$GOFILE type DB interface { - GetLatestFileByPath(ctx context.Context, path string) (*lib.File, error) + GetLatestFileByPath(ctx context.Context, path string, user_id string) (*lib.File, error) + GetUserIDByAPIKey(ctx context.Context, apiKey string) (string, error) } type DBImpl struct { @@ -22,16 +23,36 @@ func NewDB(db *sqlx.DB) DB { return &DBImpl{db: db} } -func (db DBImpl) GetLatestFileByPath(ctx context.Context, path string) (*lib.File, error) { +func (db DBImpl) GetLatestFileByPath(ctx context.Context, path string, user_id string) (*lib.File, error) { var file lib.File err := db.db.GetContext(ctx, &file, ` SELECT * FROM private.file - WHERE absolute_path = $1 + WHERE + user_id = $1 + AND absolute_path = $2 ORDER BY timestamp DESC LIMIT 1 - `, path) + `, user_id, path) if err != nil { return nil, err } return &file, nil } + +// TODO: Add test +func (db DBImpl) GetUserIDByAPIKey(ctx context.Context, apiKey string) (string, error) { + if len(apiKey) != 44 { + return "", nil + } + + var userID string + err := db.db.GetContext(ctx, &userID, ` + SELECT id FROM private.api_keys + WHERE api_key = $1 + LIMIT 1 + `, apiKey) + if err != nil { + return "", err + } + return userID, nil +} diff --git a/src/rest-api/handler/handler.go b/src/rest-api/handler/handler.go index 558e773..4b9a426 100644 --- a/src/rest-api/handler/handler.go +++ b/src/rest-api/handler/handler.go @@ -27,17 +27,33 @@ func NewHandler(db *sqlx.DB, kafka_writer *kafka.Writer) Handler { } func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + api_key := r.Header.Get("API_KEY") + + log.Println("API KEY: ", api_key) + + user_id, err := h.db.GetUserIDByAPIKey(r.Context(), api_key) + if err != nil { + http.Error(w, fmt.Sprintf("Internal server error: %s", err), http.StatusInternalServerError) + return + } + if user_id == "" { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + log.Println("User ID: ", user_id) + switch r.Method { case http.MethodGet: - h.handleGet(w, r) + h.handleGet(w, r, user_id) case http.MethodPost: - h.handlePost(w, r) + h.handlePost(w, r, user_id) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } } -func (h Handler) handleGet(w http.ResponseWriter, r *http.Request) { +func (h Handler) handleGet(w http.ResponseWriter, r *http.Request, user_id string) { _, filePath, ok := strings.Cut(r.URL.Path, "/file/") if !ok { http.Error(w, "Invalid file path", http.StatusBadRequest) @@ -48,7 +64,7 @@ func (h Handler) handleGet(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) defer cancel() - file, err := h.db.GetLatestFileByPath(ctx, filePath) + file, err := h.db.GetLatestFileByPath(ctx, filePath, user_id) if err != nil { http.Error(w, fmt.Sprintf("Internal server error: %s", err), http.StatusInternalServerError) return @@ -57,7 +73,7 @@ func (h Handler) handleGet(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "File: ", file) } -func (h Handler) handlePost(w http.ResponseWriter, r *http.Request) { +func (h Handler) handlePost(w http.ResponseWriter, r *http.Request, user_id string) { bytes, err := io.ReadAll(io.Reader(r.Body)) if err != nil { log.Fatal(err) @@ -69,6 +85,10 @@ func (h Handler) handlePost(w http.ResponseWriter, r *http.Request) { err = h.kafka_writer.WriteMessages(ctx, kafka.Message{ Key: []byte("key"), //TODO:This routes to a partition. We should probably route by agent UUID TODO: wont this negate having multiple topics Value: bytes, + Headers: []kafka.Header{{ + Key: "user_id", + Value: []byte(user_id), + }}, }) if err != nil { log.Fatal(err) diff --git a/src/rest-api/handler/handler_test.go b/src/rest-api/handler/handler_test.go index 94b58e4..4709959 100644 --- a/src/rest-api/handler/handler_test.go +++ b/src/rest-api/handler/handler_test.go @@ -4,7 +4,6 @@ import ( "fmt" "net/http" "net/http/httptest" - "strings" "testing" "github.com/Baitinq/fs-tracer-backend/lib" @@ -20,11 +19,12 @@ func TestHandleGet(t *testing.T) { handler := Handler{db: db} file := &lib.File{ + User_id: "USER_ID", Absolute_path: "/tmp/file.txt", } - db.EXPECT().GetLatestFileByPath(gomock.Any(), "/tmp/file.txt").Return(file, nil) + db.EXPECT().GetLatestFileByPath(gomock.Any(), "/tmp/file.txt", "USER_ID").Return(file, nil) - handler.handleGet(recorder, httptest.NewRequest(http.MethodGet, "/file/%2ftmp%2Ffile.txt", nil)) + handler.handleGet(recorder, httptest.NewRequest(http.MethodGet, "/file/%2ftmp%2Ffile.txt", nil), "USER_ID") require.Equal(t, http.StatusOK, recorder.Code) require.Equal(t, fmt.Sprintln("File: ", file), recorder.Body.String()) diff --git a/src/rest-api/handler/mock_db.go b/src/rest-api/handler/mock_db.go index 2d51a8f..542fd62 100644 --- a/src/rest-api/handler/mock_db.go +++ b/src/rest-api/handler/mock_db.go @@ -41,16 +41,31 @@ func (m *MockDB) EXPECT() *MockDBMockRecorder { } // GetLatestFileByPath mocks base method. -func (m *MockDB) GetLatestFileByPath(ctx context.Context, path string) (*lib.File, error) { +func (m *MockDB) GetLatestFileByPath(ctx context.Context, path, user_id string) (*lib.File, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLatestFileByPath", ctx, path) + ret := m.ctrl.Call(m, "GetLatestFileByPath", ctx, path, user_id) ret0, _ := ret[0].(*lib.File) ret1, _ := ret[1].(error) return ret0, ret1 } // GetLatestFileByPath indicates an expected call of GetLatestFileByPath. -func (mr *MockDBMockRecorder) GetLatestFileByPath(ctx, path any) *gomock.Call { +func (mr *MockDBMockRecorder) GetLatestFileByPath(ctx, path, user_id any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestFileByPath", reflect.TypeOf((*MockDB)(nil).GetLatestFileByPath), ctx, path) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestFileByPath", reflect.TypeOf((*MockDB)(nil).GetLatestFileByPath), ctx, path, user_id) +} + +// GetUserIDByAPIKey mocks base method. +func (m *MockDB) GetUserIDByAPIKey(ctx context.Context, apiKey string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserIDByAPIKey", ctx, apiKey) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserIDByAPIKey indicates an expected call of GetUserIDByAPIKey. +func (mr *MockDBMockRecorder) GetUserIDByAPIKey(ctx, apiKey any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserIDByAPIKey", reflect.TypeOf((*MockDB)(nil).GetUserIDByAPIKey), ctx, apiKey) } -- cgit 1.4.1