Use filesystem based kv store instead of sqlite
authorr <r@freesoftwareextremist.com>
Tue, 17 Dec 2019 20:17:25 +0000 (20:17 +0000)
committerr <r@freesoftwareextremist.com>
Tue, 17 Dec 2019 20:17:25 +0000 (20:17 +0000)
go.mod
go.mod.old [new file with mode: 0644]
go.sum
kv/kv.go [new file with mode: 0644]
main.go
model/app.go
model/session.go
repository/appRepository.go
repository/sessionRepository.go
service/auth.go
service/service.go

diff --git a/go.mod b/go.mod
index de1ba8901bd600e5bb5b5f334fcd082050275747..0bebfe16bf45d195beb617fa2e9e8fd031b5e3ad 100644 (file)
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,7 @@ go 1.13
 
 require (
        github.com/gorilla/mux v1.7.3
-       github.com/mattn/go-sqlite3 v2.0.1+incompatible
+       github.com/mattn/go-sqlite3 v2.0.2+incompatible // indirect
        mastodon v0.0.0-00010101000000-000000000000
 )
 
diff --git a/go.mod.old b/go.mod.old
new file mode 100644 (file)
index 0000000..e633126
--- /dev/null
@@ -0,0 +1,10 @@
+module web
+
+go 1.13
+
+require (
+       github.com/gorilla/mux v1.7.3
+       mastodon v0.0.0-00010101000000-000000000000
+)
+
+replace mastodon => ./mastodon
diff --git a/go.sum b/go.sum
index 236732dba3445447868c9315f88a8dc7a90b8016..7a53570bc0f4f6eee3f3834383845e0dea895f98 100644 (file)
--- a/go.sum
+++ b/go.sum
@@ -2,7 +2,7 @@ github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
 github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
 github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
 github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
-github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
+github.com/mattn/go-sqlite3 v2.0.2+incompatible h1:qzw9c2GNT8UFrgWNDhCTqRqYUSmu/Dav/9Z58LGpk7U=
+github.com/mattn/go-sqlite3 v2.0.2+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=
diff --git a/kv/kv.go b/kv/kv.go
new file mode 100644 (file)
index 0000000..2cfcd60
--- /dev/null
+++ b/kv/kv.go
@@ -0,0 +1,92 @@
+package kv
+
+import (
+       "errors"
+       "io/ioutil"
+       "os"
+       "path/filepath"
+       "strings"
+       "sync"
+)
+
+var (
+       errInvalidKey = errors.New("invalid key")
+       errNoSuchKey  = errors.New("no such key")
+)
+
+type Database struct {
+       data    map[string][]byte
+       basedir string
+       m       sync.RWMutex
+}
+
+func NewDatabse(basedir string) (db *Database, err error) {
+       err = os.Mkdir(basedir, 0755)
+       if err != nil && !os.IsExist(err) {
+               return
+       }
+
+       return &Database{
+               data:    make(map[string][]byte),
+               basedir: basedir,
+       }, nil
+}
+
+func (db *Database) Set(key string, val []byte) (err error) {
+       if len(key) < 1 {
+               return errInvalidKey
+       }
+
+       db.m.Lock()
+       defer func() {
+               if err != nil {
+                       delete(db.data, key)
+               }
+               db.m.Unlock()
+       }()
+
+       db.data[key] = val
+
+       err = ioutil.WriteFile(filepath.Join(db.basedir, key), val, 0644)
+
+       return
+}
+
+func (db *Database) Get(key string) (val []byte, err error) {
+       if len(key) < 1 {
+               return nil, errInvalidKey
+       }
+
+       db.m.RLock()
+       defer db.m.RUnlock()
+
+       data, ok := db.data[key]
+       if !ok {
+               data, err = ioutil.ReadFile(filepath.Join(db.basedir, key))
+               if err != nil {
+                       err = errNoSuchKey
+                       return nil, err
+               }
+
+               db.data[key] = data
+       }
+
+       val = make([]byte, len(data))
+       copy(val, data)
+
+       return
+}
+
+func (db *Database) Remove(key string) {
+       if len(key) < 1 || strings.ContainsRune(key, os.PathSeparator) {
+               return
+       }
+
+       db.m.Lock()
+       defer db.m.Unlock()
+
+       delete(db.data, key)
+       os.Remove(filepath.Join(db.basedir, key))
+
+       return
+}
diff --git a/main.go b/main.go
index d726fedbfb84abfdc0cfb07f02656c4d61078c03..ad629769d3e795c3faba28b04e378b8a10524d06 100644 (file)
--- a/main.go
+++ b/main.go
@@ -1,19 +1,18 @@
 package main
 
 import (
-       "database/sql"
        "log"
        "math/rand"
        "net/http"
        "os"
+       "path/filepath"
        "time"
 
        "web/config"
+       "web/kv"
        "web/renderer"
        "web/repository"
        "web/service"
-
-       _ "github.com/mattn/go-sqlite3"
 )
 
 func init() {
@@ -35,22 +34,24 @@ func main() {
                log.Fatal(err)
        }
 
-       db, err := sql.Open("sqlite3", config.DatabasePath)
-       if err != nil {
+       err = os.Mkdir(config.DatabasePath, 0755)
+       if err != nil && !os.IsExist(err) {
                log.Fatal(err)
        }
-       defer db.Close()
 
-       sessionRepo, err := repository.NewSessionRepository(db)
+       sessionDB, err := kv.NewDatabse(filepath.Join(config.DatabasePath, "session"))
        if err != nil {
                log.Fatal(err)
        }
 
-       appRepo, err := repository.NewAppRepository(db)
+       appDB, err := kv.NewDatabse(filepath.Join(config.DatabasePath, "app"))
        if err != nil {
                log.Fatal(err)
        }
 
+       sessionRepo := repository.NewSessionRepository(sessionDB)
+       appRepo := repository.NewAppRepository(appDB)
+
        var logger *log.Logger
        if len(config.Logfile) < 1 {
                logger = log.New(os.Stdout, "", log.LstdFlags)
index 52ebdf59de5ece9f9997be6ac7dfc93fd4b326b9..89d656d916dfb5c757f3c83b3d67bfc699388bd2 100644 (file)
@@ -1,19 +1,40 @@
 package model
 
-import "errors"
+import (
+       "errors"
+       "strings"
+)
 
 var (
        ErrAppNotFound = errors.New("app not found")
 )
 
 type App struct {
-       InstanceURL  string
-       ClientID     string
-       ClientSecret string
+       InstanceDomain string
+       InstanceURL    string
+       ClientID       string
+       ClientSecret   string
 }
 
 type AppRepository interface {
        Add(app App) (err error)
-       Update(instanceURL string, clientID string, clientSecret string) (err error)
-       Get(instanceURL string) (app App, err error)
+       Get(instanceDomain string) (app App, err error)
+}
+
+func (a *App) Marshal() []byte {
+       str := a.InstanceURL + "\n" + a.ClientID + "\n" + a.ClientSecret
+       return []byte(str)
+}
+
+func (a *App) Unmarshal(instanceDomain string, data []byte) error {
+       str := string(data)
+       lines := strings.Split(str, "\n")
+       if len(lines) != 3 {
+               return errors.New("invalid data")
+       }
+       a.InstanceDomain = instanceDomain
+       a.InstanceURL = lines[0]
+       a.ClientID = lines[1]
+       a.ClientSecret = lines[2]
+       return nil
 }
index 43628ee2c2776b5b0254883008d624a0c4f9bb90..94f527bfdaf3ac37e96a1d9822f88ab6add042c6 100644 (file)
@@ -1,15 +1,18 @@
 package model
 
-import "errors"
+import (
+       "errors"
+       "strings"
+)
 
 var (
        ErrSessionNotFound = errors.New("session not found")
 )
 
 type Session struct {
-       ID          string
-       InstanceURL string
-       AccessToken string
+       ID             string
+       InstanceDomain string
+       AccessToken    string
 }
 
 type SessionRepository interface {
@@ -21,3 +24,26 @@ type SessionRepository interface {
 func (s Session) IsLoggedIn() bool {
        return len(s.AccessToken) > 0
 }
+
+func (s *Session) Marshal() []byte {
+       str := s.InstanceDomain + "\n" + s.AccessToken
+       return []byte(str)
+}
+
+func (s *Session) Unmarshal(id string, data []byte) error {
+       str := string(data)
+       lines := strings.Split(str, "\n")
+
+       size := len(lines)
+       if size == 1 {
+               s.InstanceDomain = lines[0]
+       } else if size == 2 {
+               s.InstanceDomain = lines[0]
+               s.AccessToken = lines[1]
+       } else {
+               return errors.New("invalid data")
+       }
+
+       s.ID = id
+       return nil
+}
index 1a8f20470e90bb607b34c6c589300f63a0f563ba..00ef64d0e87b902fe4522f3c461c94e5dcd17a80 100644 (file)
@@ -1,54 +1,33 @@
 package repository
 
 import (
-       "database/sql"
-
+       "web/kv"
        "web/model"
 )
 
 type appRepository struct {
-       db *sql.DB
+       db *kv.Database
 }
 
-func NewAppRepository(db *sql.DB) (*appRepository, error) {
-       _, err := db.Exec(`CREATE TABLE IF NOT EXISTS app 
-               (instance_url varchar, client_id varchar, client_secret varchar)`,
-       )
-       if err != nil {
-               return nil, err
-       }
-
+func NewAppRepository(db *kv.Database) *appRepository {
        return &appRepository{
                db: db,
-       }, nil
+       }
 }
 
 func (repo *appRepository) Add(a model.App) (err error) {
-       _, err = repo.db.Exec("INSERT INTO app VALUES (?, ?, ?)", a.InstanceURL, a.ClientID, a.ClientSecret)
-       return
-}
-
-func (repo *appRepository) Update(instanceURL string, clientID string, clientSecret string) (err error) {
-       _, err = repo.db.Exec("UPDATE app SET client_id = ?, client_secret = ? where instance_url = ?", clientID, clientSecret, instanceURL)
+       err = repo.db.Set(a.InstanceDomain, a.Marshal())
        return
 }
 
-func (repo *appRepository) Get(instanceURL string) (a model.App, err error) {
-       rows, err := repo.db.Query("SELECT * FROM app WHERE instance_url = ?", instanceURL)
+func (repo *appRepository) Get(instanceDomain string) (a model.App, err error) {
+       data, err := repo.db.Get(instanceDomain)
        if err != nil {
-               return
-       }
-       defer rows.Close()
-
-       if !rows.Next() {
                err = model.ErrAppNotFound
                return
        }
 
-       err = rows.Scan(&a.InstanceURL, &a.ClientID, &a.ClientSecret)
-       if err != nil {
-               return
-       }
+       err = a.Unmarshal(instanceDomain, data)
 
        return
 }
index 2a88b40cbf29c81b390c5c2833a6d1d18fa5764b..6c26313b8897d2c70be6425dfda21edd28f3e47c 100644 (file)
@@ -1,54 +1,50 @@
 package repository
 
 import (
-       "database/sql"
-
+       "web/kv"
        "web/model"
 )
 
 type sessionRepository struct {
-       db *sql.DB
+       db *kv.Database
 }
 
-func NewSessionRepository(db *sql.DB) (*sessionRepository, error) {
-       _, err := db.Exec(`CREATE TABLE IF NOT EXISTS session 
-               (id varchar, instance_url varchar, access_token varchar)`,
-       )
-       if err != nil {
-               return nil, err
-       }
-
+func NewSessionRepository(db *kv.Database) *sessionRepository {
        return &sessionRepository{
                db: db,
-       }, nil
+       }
 }
 
 func (repo *sessionRepository) Add(s model.Session) (err error) {
-       _, err = repo.db.Exec("INSERT INTO session VALUES (?, ?, ?)", s.ID, s.InstanceURL, s.AccessToken)
+       err = repo.db.Set(s.ID, s.Marshal())
        return
 }
 
-func (repo *sessionRepository) Update(sessionID string, accessToken string) (err error) {
-       _, err = repo.db.Exec("UPDATE session SET access_token = ? where id = ?", accessToken, sessionID)
-       return
-}
-
-func (repo *sessionRepository) Get(id string) (s model.Session, err error) {
-       rows, err := repo.db.Query("SELECT * FROM session WHERE id = ?", id)
+func (repo *sessionRepository) Update(id string, accessToken string) (err error) {
+       data, err := repo.db.Get(id)
        if err != nil {
                return
        }
-       defer rows.Close()
 
-       if !rows.Next() {
-               err = model.ErrSessionNotFound
+       var s model.Session
+       err = s.Unmarshal(id, data)
+       if err != nil {
                return
        }
 
-       err = rows.Scan(&s.ID, &s.InstanceURL, &s.AccessToken)
+       s.AccessToken = accessToken
+
+       return repo.db.Set(id, s.Marshal())
+}
+
+func (repo *sessionRepository) Get(id string) (s model.Session, err error) {
+       data, err := repo.db.Get(id)
        if err != nil {
+               err = model.ErrSessionNotFound
                return
        }
 
+       err = s.Unmarshal(id, data)
+
        return
 }
index e9bec3801caf02fbe4cbb2fa9224acec5c7cca61..38c0a43c3eca598c298510a6f5d7453675c6adbc 100644 (file)
@@ -40,12 +40,12 @@ func (s *authService) getClient(ctx context.Context) (c *mastodon.Client, err er
        if err != nil {
                return nil, ErrInvalidSession
        }
-       client, err := s.appRepo.Get(session.InstanceURL)
+       client, err := s.appRepo.Get(session.InstanceDomain)
        if err != nil {
                return
        }
        c = mastodon.NewClient(&mastodon.Config{
-               Server:       session.InstanceURL,
+               Server:       client.InstanceURL,
                ClientID:     client.ClientID,
                ClientSecret: client.ClientSecret,
                AccessToken:  session.AccessToken,
index 5181475c595be82d0a17cd112890131872e4e275..bb03c263401347ed6e37b78c49229d62017e15d0 100644 (file)
@@ -9,7 +9,6 @@ import (
        "mime/multipart"
        "net/http"
        "net/url"
-       "path"
        "strings"
 
        "mastodon"
@@ -64,14 +63,18 @@ func NewService(clientName string, clientScope string, clientWebsite string,
 
 func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
        redirectUrl string, sessionID string, err error) {
-       if !strings.HasPrefix(instance, "https://") {
-               instance = "https://" + instance
+       var instanceURL string
+       if strings.HasPrefix(instance, "https://") {
+               instanceURL = instance
+               instance = strings.TrimPrefix(instance, "https://")
+       } else {
+               instanceURL = "https://" + instance
        }
 
        sessionID = util.NewSessionId()
        err = svc.sessionRepo.Add(model.Session{
-               ID:          sessionID,
-               InstanceURL: instance,
+               ID:             sessionID,
+               InstanceDomain: instance,
        })
        if err != nil {
                return
@@ -85,7 +88,7 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
 
                var mastoApp *mastodon.Application
                mastoApp, err = mastodon.RegisterApp(ctx, &mastodon.AppConfig{
-                       Server:       instance,
+                       Server:       instanceURL,
                        ClientName:   svc.clientName,
                        Scopes:       svc.clientScope,
                        Website:      svc.clientWebsite,
@@ -96,9 +99,10 @@ func (svc *service) GetAuthUrl(ctx context.Context, instance string) (
                }
 
                app = model.App{
-                       InstanceURL:  instance,
-                       ClientID:     mastoApp.ClientID,
-                       ClientSecret: mastoApp.ClientSecret,
+                       InstanceDomain: instance,
+                       InstanceURL:    instanceURL,
+                       ClientID:       mastoApp.ClientID,
+                       ClientSecret:   mastoApp.ClientSecret,
                }
 
                err = svc.appRepo.Add(app)
@@ -136,7 +140,7 @@ func (svc *service) GetUserToken(ctx context.Context, sessionID string, c *masto
                return
        }
 
-       app, err := svc.appRepo.Get(session.InstanceURL)
+       app, err := svc.appRepo.Get(session.InstanceDomain)
        if err != nil {
                return
        }