feat: add main.go file
This commit is contained in:
commit
313bc7f9bb
1 changed files with 249 additions and 0 deletions
249
main.go
Normal file
249
main.go
Normal file
|
@ -0,0 +1,249 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/num30/config"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
var staticFiles embed.FS
|
||||
|
||||
//go:embed database/sqlite/migration/*
|
||||
var sqliteMigrations embed.FS
|
||||
|
||||
type OAuthConfig struct {
|
||||
ClientID string `config:"client_id"`
|
||||
ClientSecret string `config:"client_secret"`
|
||||
AuthURL string `config:"auth_url"`
|
||||
TokenURL string `config:"token_url"`
|
||||
UserInfoURL string `config:"userinfo_url"`
|
||||
}
|
||||
|
||||
type DBConfig struct {
|
||||
Driver string `config:"driver"`
|
||||
Host string `config:"host"`
|
||||
Port int `config:"port"`
|
||||
User string `config:"user"`
|
||||
Password string `config:"password"`
|
||||
DBName string `config:"dbname"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
env string
|
||||
loglevel string
|
||||
Server struct {
|
||||
Port int `config:"port"`
|
||||
} `config:"server"`
|
||||
OAuth OAuthConfig `config:"oauth"`
|
||||
DB DBConfig `config:"db"`
|
||||
}
|
||||
|
||||
var (
|
||||
oauthConfig *oauth2.Config
|
||||
cfg Config
|
||||
logger *zap.Logger
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
func main() {
|
||||
cfgReader := config.NewConfReader("config").WithSearchDirs("./examples", ".")
|
||||
cfgErr := cfgReader.Read(&cfg)
|
||||
|
||||
zapcfg := zap.NewProductionConfig()
|
||||
if cfg.env == "dev" {
|
||||
zapcfg = zap.NewDevelopmentConfig()
|
||||
}
|
||||
logger, err := zapcfg.Build()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to initialize logger: %v", err))
|
||||
}
|
||||
defer logger.Sync()
|
||||
|
||||
if cfgErr != nil {
|
||||
logger.Fatal("Failed to load config", zap.Error(err))
|
||||
}
|
||||
fmt.Println("test")
|
||||
|
||||
logger.Info("Loaded config", zap.Any("config", cfg))
|
||||
|
||||
// Initialize OAuth2
|
||||
oauthConfig = &oauth2.Config{
|
||||
ClientID: cfg.OAuth.ClientID,
|
||||
ClientSecret: cfg.OAuth.ClientSecret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: cfg.OAuth.AuthURL,
|
||||
TokenURL: cfg.OAuth.TokenURL,
|
||||
},
|
||||
RedirectURL: cfg.OAuth.AuthURL,
|
||||
Scopes: []string{"profile", "email"},
|
||||
}
|
||||
|
||||
// Initialize database connection
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
|
||||
cfg.DB.User, cfg.DB.Password, cfg.DB.Host, cfg.DB.Port, cfg.DB.DBName)
|
||||
db, err = sql.Open(cfg.DB.Driver, dsn)
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to connect to database", zap.Error(err))
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
logger.Info("migrating sqlite database")
|
||||
if err := migrateSqlite(db); err != nil {
|
||||
logger.Fatal("Failed to migrate database", zap.Error(err))
|
||||
}
|
||||
|
||||
// Initialize Echo
|
||||
e := echo.New()
|
||||
e.Use(middleware.Logger())
|
||||
e.Use(middleware.Recover())
|
||||
|
||||
// Serve static files
|
||||
e.GET("/*", echo.WrapHandler(http.FileServer(http.FS(staticFiles))))
|
||||
|
||||
// OAuth2 routes
|
||||
e.GET("/oauth/login", handleOAuthLogin)
|
||||
e.GET("/oauth/callback", handleOAuthCallback)
|
||||
|
||||
// Protected API routes
|
||||
api := e.Group("/api/v1", oauthMiddleware)
|
||||
api.GET("/health", func(c echo.Context) error {
|
||||
user := c.Get("user").(map[string]interface{})
|
||||
logger.Info("Health check requested", zap.Any("user", user))
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"status": "ok",
|
||||
"user": user,
|
||||
})
|
||||
})
|
||||
|
||||
// Start server
|
||||
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
||||
logger.Info("Starting server", zap.Int("port", cfg.Server.Port))
|
||||
if err := e.Start(addr); err != nil {
|
||||
logger.Fatal("Server error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func migrateSqlite(db *sql.DB) error {
|
||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migration driver: %w", err)
|
||||
}
|
||||
|
||||
sourceInstance, err := iofs.New(sqliteMigrations, "database/sqlite/migration")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migration source: %w", err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithInstance("iofs", sourceInstance, "sqlite3", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize migrations: %w", err)
|
||||
}
|
||||
|
||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuth login handler
|
||||
func handleOAuthLogin(c echo.Context) error {
|
||||
url := oauthConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
||||
logger.Info("Redirecting to OAuth login", zap.String("url", url))
|
||||
return c.Redirect(http.StatusFound, url)
|
||||
}
|
||||
|
||||
// OAuth callback handler
|
||||
func handleOAuthCallback(c echo.Context) error {
|
||||
code := c.QueryParam("code")
|
||||
if code == "" {
|
||||
logger.Warn("Missing OAuth code in callback")
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing code"})
|
||||
}
|
||||
|
||||
token, err := oauthConfig.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
logger.Error("OAuth token exchange failed", zap.Error(err))
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to exchange token"})
|
||||
}
|
||||
|
||||
logger.Info("OAuth token exchanged successfully", zap.String("access_token", token.AccessToken))
|
||||
return c.JSON(http.StatusOK, map[string]interface{}{
|
||||
"access_token": token.AccessToken,
|
||||
"expires_in": token.Expiry,
|
||||
})
|
||||
}
|
||||
|
||||
// Middleware to enforce OAuth authentication
|
||||
func oauthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
authHeader := c.Request().Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
logger.Warn("Missing Authorization header")
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "missing Authorization header"})
|
||||
}
|
||||
|
||||
// Extract Bearer token
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token == authHeader {
|
||||
logger.Warn("Invalid Authorization header format")
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid Authorization header"})
|
||||
}
|
||||
|
||||
// Validate token and fetch user info
|
||||
userInfo, err := fetchUserInfo(token)
|
||||
if err != nil {
|
||||
logger.Error("OAuth token validation failed", zap.Error(err))
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "invalid token"})
|
||||
}
|
||||
|
||||
logger.Info("User authenticated", zap.Any("user", userInfo))
|
||||
c.Set("user", userInfo)
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch user info from OAuth provider
|
||||
func fetchUserInfo(token string) (map[string]interface{}, error) {
|
||||
req, err := http.NewRequest("GET", cfg.OAuth.UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
client := http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("OAuth validation failed: %s", string(body))
|
||||
}
|
||||
|
||||
var userInfo map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userInfo, nil
|
||||
}
|
Loading…
Add table
Reference in a new issue