Refactor everything
[bloat] / service / auth.go
1 package service
2
3 import (
4         "context"
5         "errors"
6         "mime/multipart"
7
8         "bloat/model"
9         "mastodon"
10 )
11
12 var (
13         errInvalidSession   = errors.New("invalid session")
14         errInvalidCSRFToken = errors.New("invalid csrf token")
15 )
16
17 type as struct {
18         sessionRepo model.SessionRepo
19         appRepo     model.AppRepo
20         Service
21 }
22
23 func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service {
24         return &as{sessionRepo, appRepo, s}
25 }
26
27 func (s *as) authenticateClient(ctx context.Context, c *model.Client) (err error) {
28         sessionID, ok := ctx.Value("session_id").(string)
29         if !ok || len(sessionID) < 1 {
30                 return errInvalidSession
31         }
32         session, err := s.sessionRepo.Get(sessionID)
33         if err != nil {
34                 return errInvalidSession
35         }
36         client, err := s.appRepo.Get(session.InstanceDomain)
37         if err != nil {
38                 return
39         }
40         mc := mastodon.NewClient(&mastodon.Config{
41                 Server:       client.InstanceURL,
42                 ClientID:     client.ClientID,
43                 ClientSecret: client.ClientSecret,
44                 AccessToken:  session.AccessToken,
45         })
46         if c == nil {
47                 c = &model.Client{}
48         }
49         c.Client = mc
50         c.Session = session
51         return nil
52 }
53
54 func checkCSRF(ctx context.Context, c *model.Client) (err error) {
55         token, ok := ctx.Value("csrf_token").(string)
56         if !ok || token != c.Session.CSRFToken {
57                 return errInvalidCSRFToken
58         }
59         return nil
60 }
61
62 func (s *as) ServeErrorPage(ctx context.Context, c *model.Client, err error) {
63         s.authenticateClient(ctx, c)
64         s.Service.ServeErrorPage(ctx, c, err)
65 }
66
67 func (s *as) ServeSigninPage(ctx context.Context, c *model.Client) (err error) {
68         return s.Service.ServeSigninPage(ctx, c)
69 }
70
71 func (s *as) ServeTimelinePage(ctx context.Context, c *model.Client, tType string,
72         maxID string, minID string) (err error) {
73         err = s.authenticateClient(ctx, c)
74         if err != nil {
75                 return
76         }
77         return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID)
78 }
79
80 func (s *as) ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) {
81         err = s.authenticateClient(ctx, c)
82         if err != nil {
83                 return
84         }
85         return s.Service.ServeThreadPage(ctx, c, id, reply)
86 }
87
88 func (s *as) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) {
89         err = s.authenticateClient(ctx, c)
90         if err != nil {
91                 return
92         }
93         return s.Service.ServeLikedByPage(ctx, c, id)
94 }
95
96 func (s *as) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) {
97         err = s.authenticateClient(ctx, c)
98         if err != nil {
99                 return
100         }
101         return s.Service.ServeRetweetedByPage(ctx, c, id)
102 }
103
104 func (s *as) ServeFollowingPage(ctx context.Context, c *model.Client, id string,
105         maxID string, minID string) (err error) {
106         err = s.authenticateClient(ctx, c)
107         if err != nil {
108                 return
109         }
110         return s.Service.ServeFollowingPage(ctx, c, id, maxID, minID)
111 }
112
113 func (s *as) ServeFollowersPage(ctx context.Context, c *model.Client, id string,
114         maxID string, minID string) (err error) {
115         err = s.authenticateClient(ctx, c)
116         if err != nil {
117                 return
118         }
119         return s.Service.ServeFollowersPage(ctx, c, id, maxID, minID)
120 }
121
122 func (s *as) ServeNotificationPage(ctx context.Context, c *model.Client,
123         maxID string, minID string) (err error) {
124         err = s.authenticateClient(ctx, c)
125         if err != nil {
126                 return
127         }
128         return s.Service.ServeNotificationPage(ctx, c, maxID, minID)
129 }
130
131 func (s *as) ServeUserPage(ctx context.Context, c *model.Client, id string,
132         maxID string, minID string) (err error) {
133         err = s.authenticateClient(ctx, c)
134         if err != nil {
135                 return
136         }
137         return s.Service.ServeUserPage(ctx, c, id, maxID, minID)
138 }
139
140 func (s *as) ServeAboutPage(ctx context.Context, c *model.Client) (err error) {
141         err = s.authenticateClient(ctx, c)
142         if err != nil {
143                 return
144         }
145         return s.Service.ServeAboutPage(ctx, c)
146 }
147
148 func (s *as) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) {
149         err = s.authenticateClient(ctx, c)
150         if err != nil {
151                 return
152         }
153         return s.Service.ServeEmojiPage(ctx, c)
154 }
155
156 func (s *as) ServeSearchPage(ctx context.Context, c *model.Client, q string,
157         qType string, offset int) (err error) {
158         err = s.authenticateClient(ctx, c)
159         if err != nil {
160                 return
161         }
162         return s.Service.ServeSearchPage(ctx, c, q, qType, offset)
163 }
164
165 func (s *as) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) {
166         err = s.authenticateClient(ctx, c)
167         if err != nil {
168                 return
169         }
170         return s.Service.ServeSettingsPage(ctx, c)
171 }
172
173 func (s *as) NewSession(ctx context.Context, instance string) (redirectUrl string,
174         sessionID string, err error) {
175         return s.Service.NewSession(ctx, instance)
176 }
177
178 func (s *as) Signin(ctx context.Context, c *model.Client, sessionID string,
179         code string) (token string, err error) {
180         err = s.authenticateClient(ctx, c)
181         if err != nil {
182                 return
183         }
184
185         token, err = s.Service.Signin(ctx, c, c.Session.ID, code)
186         if err != nil {
187                 return
188         }
189
190         c.Session.AccessToken = token
191         err = s.sessionRepo.Add(c.Session)
192         if err != nil {
193                 return
194         }
195
196         return
197 }
198
199 func (s *as) Post(ctx context.Context, c *model.Client, content string,
200         replyToID string, format string, visibility string, isNSFW bool,
201         files []*multipart.FileHeader) (id string, err error) {
202         err = s.authenticateClient(ctx, c)
203         if err != nil {
204                 return
205         }
206         err = checkCSRF(ctx, c)
207         if err != nil {
208                 return
209         }
210         return s.Service.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files)
211 }
212
213 func (s *as) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) {
214         err = s.authenticateClient(ctx, c)
215         if err != nil {
216                 return
217         }
218         err = checkCSRF(ctx, c)
219         if err != nil {
220                 return
221         }
222         return s.Service.Like(ctx, c, id)
223 }
224
225 func (s *as) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) {
226         err = s.authenticateClient(ctx, c)
227         if err != nil {
228                 return
229         }
230         err = checkCSRF(ctx, c)
231         if err != nil {
232                 return
233         }
234         return s.Service.UnLike(ctx, c, id)
235 }
236
237 func (s *as) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {
238         err = s.authenticateClient(ctx, c)
239         if err != nil {
240                 return
241         }
242         err = checkCSRF(ctx, c)
243         if err != nil {
244                 return
245         }
246         return s.Service.Retweet(ctx, c, id)
247 }
248
249 func (s *as) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {
250         err = s.authenticateClient(ctx, c)
251         if err != nil {
252                 return
253         }
254         err = checkCSRF(ctx, c)
255         if err != nil {
256                 return
257         }
258         return s.Service.UnRetweet(ctx, c, id)
259 }
260
261 func (s *as) Follow(ctx context.Context, c *model.Client, id string) (err error) {
262         err = s.authenticateClient(ctx, c)
263         if err != nil {
264                 return
265         }
266         err = checkCSRF(ctx, c)
267         if err != nil {
268                 return
269         }
270         return s.Service.Follow(ctx, c, id)
271 }
272
273 func (s *as) UnFollow(ctx context.Context, c *model.Client, id string) (err error) {
274         err = s.authenticateClient(ctx, c)
275         if err != nil {
276                 return
277         }
278         err = checkCSRF(ctx, c)
279         if err != nil {
280                 return
281         }
282         return s.Service.UnFollow(ctx, c, id)
283 }
284
285 func (s *as) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) {
286         err = s.authenticateClient(ctx, c)
287         if err != nil {
288                 return
289         }
290         err = checkCSRF(ctx, c)
291         if err != nil {
292                 return
293         }
294         return s.Service.SaveSettings(ctx, c, settings)
295 }