commit 313bc7f9bb623fb6889a8eb3eb2d703a9391d947 Author: garionion Date: Tue Feb 25 23:01:34 2025 +0100 feat: add main.go file diff --git a/main.go b/main.go new file mode 100644 index 0000000..024f753 --- /dev/null +++ b/main.go @@ -0,0 +1,249 @@ +package main + +import ( + "context" + "database/sql" + "embed" + "encoding/json" + "fmt" + "github.com/num30/config" + "go.uber.org/zap" + "io" + "net/http" + "strings" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + _ "github.com/mattn/go-sqlite3" + "golang.org/x/oauth2" +) + +//go:embed static/* +var staticFiles embed.FS + +//go:embed database/sqlite/migration/* +var sqliteMigrations embed.FS + +type OAuthConfig struct { + ClientID string `config:"client_id"` + ClientSecret string `config:"client_secret"` + AuthURL string `config:"auth_url"` + TokenURL string `config:"token_url"` + UserInfoURL string `config:"userinfo_url"` +} + +type DBConfig struct { + Driver string `config:"driver"` + Host string `config:"host"` + Port int `config:"port"` + User string `config:"user"` + Password string `config:"password"` + DBName string `config:"dbname"` +} + +type Config struct { + env string + loglevel string + Server struct { + Port int `config:"port"` + } `config:"server"` + OAuth OAuthConfig `config:"oauth"` + DB DBConfig `config:"db"` +} + +var ( + oauthConfig *oauth2.Config + cfg Config + logger *zap.Logger +) + +var db *sql.DB + +func main() { + cfgReader := config.NewConfReader("config").WithSearchDirs("./examples", ".") + cfgErr := cfgReader.Read(&cfg) + + zapcfg := zap.NewProductionConfig() + if cfg.env == "dev" { + zapcfg = zap.NewDevelopmentConfig() + } + logger, err := zapcfg.Build() + if err != nil { + panic(fmt.Sprintf("Failed to initialize logger: %v", err)) + } + defer logger.Sync() + + if cfgErr != nil { + logger.Fatal("Failed to load config", zap.Error(err)) + } + fmt.Println("test") + + logger.Info("Loaded config", zap.Any("config", cfg)) + + // Initialize OAuth2 + oauthConfig = &oauth2.Config{ + ClientID: cfg.OAuth.ClientID, + ClientSecret: cfg.OAuth.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: cfg.OAuth.AuthURL, + TokenURL: cfg.OAuth.TokenURL, + }, + RedirectURL: cfg.OAuth.AuthURL, + Scopes: []string{"profile", "email"}, + } + + // Initialize database connection + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", + cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName) + db, err = sql.Open(cfg.DB.Driver, dsn) + if err != nil { + logger.Fatal("Failed to connect to database", zap.Error(err)) + } + defer db.Close() + + logger.Info("migrating sqlite database") + if err := migrateSqlite(db); err != nil { + logger.Fatal("Failed to migrate database", zap.Error(err)) + } + + // Initialize Echo + e := echo.New() + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) + + // Serve static files + e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticFiles)))) + + // OAuth2 routes + e.GET("/oauth/login", handleOAuthLogin) + e.GET("/oauth/callback", handleOAuthCallback) + + // Protected API routes + api := e.Group("/api/v1", oauthMiddleware) + api.GET("/health", func(c echo.Context) error { + user := c.Get("user").(map[string]interface{}) + logger.Info("Health check requested", zap.Any("user", user)) + return c.JSON(http.StatusOK, map[string]interface{}{ + "status": "ok", + "user": user, + }) + }) + + // Start server + addr := fmt.Sprintf(":%d", cfg.Server.Port) + logger.Info("Starting server", zap.Int("port", cfg.Server.Port)) + if err := e.Start(addr); err != nil { + logger.Fatal("Server error", zap.Error(err)) + } +} + +func migrateSqlite(db *sql.DB) error { + driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) + if err != nil { + return fmt.Errorf("failed to create migration driver: %w", err) + } + + sourceInstance, err := iofs.New(sqliteMigrations, "database/sqlite/migration") + if err != nil { + return fmt.Errorf("failed to create migration source: %w", err) + } + + m, err := migrate.NewWithInstance("iofs", sourceInstance, "sqlite3", driver) + if err != nil { + return fmt.Errorf("failed to initialize migrations: %w", err) + } + + if err := m.Up(); err != nil && err != migrate.ErrNoChange { + return fmt.Errorf("failed to run migrations: %w", err) + } + + return nil +} + +// OAuth login handler +func handleOAuthLogin(c echo.Context) error { + url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline) + logger.Info("Redirecting to OAuth login", zap.String("url", url)) + return c.Redirect(http.StatusFound, url) +} + +// OAuth callback handler +func handleOAuthCallback(c echo.Context) error { + code := c.QueryParam("code") + if code == "" { + logger.Warn("Missing OAuth code in callback") + return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"}) + } + + token, err := oauthConfig.Exchange(context.Background(), code) + if err != nil { + logger.Error("OAuth token exchange failed", zap.Error(err)) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"}) + } + + logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken)) + return c.JSON(http.StatusOK, map[string]interface{}{ + "access_token": token.AccessToken, + "expires_in": token.Expiry, + }) +} + +// Middleware to enforce OAuth authentication +func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + logger.Warn("Missing Authorization header") + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"}) + } + + // Extract Bearer token + token := strings.TrimPrefix(authHeader, "Bearer ") + if token == authHeader { + logger.Warn("Invalid Authorization header format") + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"}) + } + + // Validate token and fetch user info + userInfo, err := fetchUserInfo(token) + if err != nil { + logger.Error("OAuth token validation failed", zap.Error(err)) + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"}) + } + + logger.Info("User authenticated", zap.Any("user", userInfo)) + c.Set("user", userInfo) + return next(c) + } +} + +// Fetch user info from OAuth provider +func fetchUserInfo(token string) (map[string]interface{}, error) { + req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OAuth validation failed: %s", string(body)) + } + + var userInfo map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return nil, err + } + + return userInfo, nil +}