package main import ( "context" "database/sql" "embed" "encoding/json" "fmt" "github.com/num30/config" "github.com/pressly/goose/v3" "go.uber.org/zap" "io" "net/http" "strings" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" _ "github.com/mattn/go-sqlite3" "golang.org/x/oauth2" ) //go:embed static/* var staticFiles embed.FS //go:embed database/sqlite/migration/*.sql var sqliteMigrations embed.FS type OAuthConfig struct { ClientID string `config:"client_id"` ClientSecret string `config:"client_secret"` AuthURL string `config:"auth_url"` TokenURL string `config:"token_url"` UserInfoURL string `config:"userinfo_url"` } type DBConfig struct { Driver string `config:"driver"` Host string `config:"host"` Port int `config:"port"` User string `config:"user"` Password string `config:"password"` DBName string `config:"dbname"` } type Config struct { Env string `config:"env"` Loglevel string `config:"loglevel"` Server struct { Port int `config:"port"` } `config:"server"` OAuth OAuthConfig `config:"oauth"` DB DBConfig `config:"db"` } var ( oauthConfig *oauth2.Config cfg Config logger *zap.Logger ) var db *sql.DB func main() { // Use the global cfg variable instead of declaring a new one cfgReader := config.NewConfReader("config").WithSearchDirs("./", ".") cfgErr := cfgReader.Read(&cfg) zapcfg := zap.NewProductionConfig() if cfg.Env == "dev" { zapcfg = zap.NewDevelopmentConfig() } logger, err := zapcfg.Build() if err != nil { panic(fmt.Sprintf("Failed to initialize logger: %v", err)) } defer logger.Sync() if cfgErr != nil { logger.Fatal("Failed to load config", zap.Error(cfgErr)) } logger.Info("Loaded config", zap.Any("config", cfg)) // Initialize OAuth2 oauthConfig = &oauth2.Config{ ClientID: cfg.OAuth.ClientID, ClientSecret: cfg.OAuth.ClientSecret, Endpoint: oauth2.Endpoint{ AuthURL: cfg.OAuth.AuthURL, TokenURL: cfg.OAuth.TokenURL, }, RedirectURL: fmt.Sprintf("http://localhost:%d/oauth/callback", cfg.Server.Port), Scopes: []string{"profile", "email"}, } fmt.Println(oauthConfig) // Initialize database connection var dsn string if cfg.DB.Driver == "sqlite3" { // For SQLite, the DSN is just the path to the database file dsn = cfg.DB.DBName } else { // For other databases like MySQL, use the standard connection string dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName) } db, err = sql.Open(cfg.DB.Driver, dsn) if err != nil { logger.Fatal("Failed to connect to database", zap.Error(err)) } defer db.Close() // Test the connection if err = db.Ping(); err != nil { logger.Fatal("Failed to ping database", zap.Error(err)) } logger.Info("Successfully connected to database", zap.String("driver", cfg.DB.Driver)) if cfg.DB.Driver == "sqlite3" { logger.Info("migrating sqlite database") if err := migrateSqlite(db); err != nil { logger.Fatal("Failed to migrate database", zap.Error(err)) } } // Initialize Echo e := echo.New() e.Use(middleware.Logger()) //e.Use(middleware.Recover()) e.Use(middleware.CORS()) // OAuth2 routes e.GET("/oauth/login", handleOAuthLogin) e.GET("/oauth/callback", handleOAuthCallback) // Protected API routes api := e.Group("/api/v1", oauthMiddleware) api.GET("/health", func(c echo.Context) error { user := c.Get("user").(map[string]interface{}) logger.Info("Health check requested", zap.Any("user", user)) return c.JSON(http.StatusOK, map[string]interface{}{ "status": "ok", "user": user, }) }) // Serve static files - must be after API routes e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticFiles)))) // Start server addr := fmt.Sprintf(":%d", cfg.Server.Port) logger.Info("Starting server", zap.Int("port", cfg.Server.Port)) if err := e.Start(addr); err != nil { logger.Fatal("Server error", zap.Error(err)) } } func migrateSqlite(db *sql.DB) error { goose.SetBaseFS(sqliteMigrations) if err := goose.SetDialect("sqlite3"); err != nil { return fmt.Errorf("failed to set db type for migration: %w", err) } if err := goose.Up(db, "database/sqlite/migration"); err != nil { return fmt.Errorf("failed to apply db migrations: %w", err) } return nil } // OAuth login handler func handleOAuthLogin(c echo.Context) error { url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline) logger.Info("Redirecting to OAuth login", zap.String("url", url)) return c.Redirect(http.StatusFound, url) } // OAuth callback handler func handleOAuthCallback(c echo.Context) error { code := c.QueryParam("code") if code == "" { logger.Warn("Missing OAuth code in callback") return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"}) } token, err := oauthConfig.Exchange(context.Background(), code) if err != nil { logger.Error("OAuth token exchange failed", zap.Error(err)) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"}) } logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken)) // Return HTML that stores the token and redirects to the main app html := fmt.Sprintf(` Authentication Successful

Authentication successful. Redirecting...

`, token.AccessToken) return c.HTML(http.StatusOK, html) } // Middleware to enforce OAuth authentication func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { authHeader := c.Request().Header.Get("Authorization") if authHeader == "" { logger.Warn("Missing Authorization header") return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"}) } // Extract Bearer token token := strings.TrimPrefix(authHeader, "Bearer ") if token == authHeader { logger.Warn("Invalid Authorization header format") return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"}) } // Validate token and fetch user info userInfo, err := fetchUserInfo(token) if err != nil { logger.Error("OAuth token validation failed", zap.Error(err)) return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"}) } logger.Info("User authenticated", zap.Any("user", userInfo)) c.Set("user", userInfo) return next(c) } } // Fetch user info from OAuth provider func fetchUserInfo(token string) (map[string]interface{}, error) { req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+token) client := http.Client{} resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("OAuth validation failed: %s", string(body)) } var userInfo map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { return nil, err } return userInfo, nil }