Use vendored dependencies
[bloat] / service / auth.go
1 package service
2
3 import (
4         "context"
5         "errors"
6         "mime/multipart"
7
8         "bloat/mastodon"
9         "bloat/model"
10 )
11
12 var (
13         errInvalidSession   = errors.New("invalid session")
14         errInvalidCSRFToken = errors.New("invalid csrf token")
15 )
16
17 type as struct {
18         sessionRepo model.SessionRepo
19         appRepo     model.AppRepo
20         Service
21 }
22
23 func NewAuthService(sessionRepo model.SessionRepo, appRepo model.AppRepo, s Service) Service {
24         return &as{sessionRepo, appRepo, s}
25 }
26
27 func (s *as) authenticateClient(ctx context.Context, c *model.Client) (err error) {
28         sessionID, ok := ctx.Value("session_id").(string)
29         if !ok || len(sessionID) < 1 {
30                 return errInvalidSession
31         }
32         session, err := s.sessionRepo.Get(sessionID)
33         if err != nil {
34                 return errInvalidSession
35         }
36         client, err := s.appRepo.Get(session.InstanceDomain)
37         if err != nil {
38                 return
39         }
40         mc := mastodon.NewClient(&mastodon.Config{
41                 Server:       client.InstanceURL,
42                 ClientID:     client.ClientID,
43                 ClientSecret: client.ClientSecret,
44                 AccessToken:  session.AccessToken,
45         })
46         if c == nil {
47                 c = &model.Client{}
48         }
49         c.Client = mc
50         c.Session = session
51         return nil
52 }
53
54 func checkCSRF(ctx context.Context, c *model.Client) (err error) {
55         token, ok := ctx.Value("csrf_token").(string)
56         if !ok || token != c.Session.CSRFToken {
57                 return errInvalidCSRFToken
58         }
59         return nil
60 }
61
62 func (s *as) ServeErrorPage(ctx context.Context, c *model.Client, err error) {
63         s.authenticateClient(ctx, c)
64         s.Service.ServeErrorPage(ctx, c, err)
65 }
66
67 func (s *as) ServeSigninPage(ctx context.Context, c *model.Client) (err error) {
68         return s.Service.ServeSigninPage(ctx, c)
69 }
70
71 func (s *as) ServeTimelinePage(ctx context.Context, c *model.Client, tType string,
72         maxID string, minID string) (err error) {
73         err = s.authenticateClient(ctx, c)
74         if err != nil {
75                 return
76         }
77         return s.Service.ServeTimelinePage(ctx, c, tType, maxID, minID)
78 }
79
80 func (s *as) ServeThreadPage(ctx context.Context, c *model.Client, id string, reply bool) (err error) {
81         err = s.authenticateClient(ctx, c)
82         if err != nil {
83                 return
84         }
85         return s.Service.ServeThreadPage(ctx, c, id, reply)
86 }
87
88 func (s *as) ServeLikedByPage(ctx context.Context, c *model.Client, id string) (err error) {
89         err = s.authenticateClient(ctx, c)
90         if err != nil {
91                 return
92         }
93         return s.Service.ServeLikedByPage(ctx, c, id)
94 }
95
96 func (s *as) ServeRetweetedByPage(ctx context.Context, c *model.Client, id string) (err error) {
97         err = s.authenticateClient(ctx, c)
98         if err != nil {
99                 return
100         }
101         return s.Service.ServeRetweetedByPage(ctx, c, id)
102 }
103
104 func (s *as) ServeNotificationPage(ctx context.Context, c *model.Client,
105         maxID string, minID string) (err error) {
106         err = s.authenticateClient(ctx, c)
107         if err != nil {
108                 return
109         }
110         return s.Service.ServeNotificationPage(ctx, c, maxID, minID)
111 }
112
113 func (s *as) ServeUserPage(ctx context.Context, c *model.Client, id string,
114         pageType string, maxID string, minID string) (err error) {
115         err = s.authenticateClient(ctx, c)
116         if err != nil {
117                 return
118         }
119         return s.Service.ServeUserPage(ctx, c, id, pageType, maxID, minID)
120 }
121
122 func (s *as) ServeAboutPage(ctx context.Context, c *model.Client) (err error) {
123         err = s.authenticateClient(ctx, c)
124         if err != nil {
125                 return
126         }
127         return s.Service.ServeAboutPage(ctx, c)
128 }
129
130 func (s *as) ServeEmojiPage(ctx context.Context, c *model.Client) (err error) {
131         err = s.authenticateClient(ctx, c)
132         if err != nil {
133                 return
134         }
135         return s.Service.ServeEmojiPage(ctx, c)
136 }
137
138 func (s *as) ServeSearchPage(ctx context.Context, c *model.Client, q string,
139         qType string, offset int) (err error) {
140         err = s.authenticateClient(ctx, c)
141         if err != nil {
142                 return
143         }
144         return s.Service.ServeSearchPage(ctx, c, q, qType, offset)
145 }
146
147 func (s *as) ServeUserSearchPage(ctx context.Context, c *model.Client,
148         id string, q string, offset int) (err error) {
149         err = s.authenticateClient(ctx, c)
150         if err != nil {
151                 return
152         }
153         return s.Service.ServeUserSearchPage(ctx, c, id, q, offset)
154 }
155
156 func (s *as) ServeSettingsPage(ctx context.Context, c *model.Client) (err error) {
157         err = s.authenticateClient(ctx, c)
158         if err != nil {
159                 return
160         }
161         return s.Service.ServeSettingsPage(ctx, c)
162 }
163
164 func (s *as) NewSession(ctx context.Context, instance string) (redirectUrl string,
165         sessionID string, err error) {
166         return s.Service.NewSession(ctx, instance)
167 }
168
169 func (s *as) Signin(ctx context.Context, c *model.Client, sessionID string,
170         code string) (token string, err error) {
171         err = s.authenticateClient(ctx, c)
172         if err != nil {
173                 return
174         }
175
176         token, err = s.Service.Signin(ctx, c, c.Session.ID, code)
177         if err != nil {
178                 return
179         }
180
181         c.Session.AccessToken = token
182         err = s.sessionRepo.Add(c.Session)
183         if err != nil {
184                 return
185         }
186
187         return
188 }
189
190 func (s *as) Post(ctx context.Context, c *model.Client, content string,
191         replyToID string, format string, visibility string, isNSFW bool,
192         files []*multipart.FileHeader) (id string, err error) {
193         err = s.authenticateClient(ctx, c)
194         if err != nil {
195                 return
196         }
197         err = checkCSRF(ctx, c)
198         if err != nil {
199                 return
200         }
201         return s.Service.Post(ctx, c, content, replyToID, format, visibility, isNSFW, files)
202 }
203
204 func (s *as) Like(ctx context.Context, c *model.Client, id string) (count int64, err error) {
205         err = s.authenticateClient(ctx, c)
206         if err != nil {
207                 return
208         }
209         err = checkCSRF(ctx, c)
210         if err != nil {
211                 return
212         }
213         return s.Service.Like(ctx, c, id)
214 }
215
216 func (s *as) UnLike(ctx context.Context, c *model.Client, id string) (count int64, err error) {
217         err = s.authenticateClient(ctx, c)
218         if err != nil {
219                 return
220         }
221         err = checkCSRF(ctx, c)
222         if err != nil {
223                 return
224         }
225         return s.Service.UnLike(ctx, c, id)
226 }
227
228 func (s *as) Retweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {
229         err = s.authenticateClient(ctx, c)
230         if err != nil {
231                 return
232         }
233         err = checkCSRF(ctx, c)
234         if err != nil {
235                 return
236         }
237         return s.Service.Retweet(ctx, c, id)
238 }
239
240 func (s *as) UnRetweet(ctx context.Context, c *model.Client, id string) (count int64, err error) {
241         err = s.authenticateClient(ctx, c)
242         if err != nil {
243                 return
244         }
245         err = checkCSRF(ctx, c)
246         if err != nil {
247                 return
248         }
249         return s.Service.UnRetweet(ctx, c, id)
250 }
251
252 func (s *as) Follow(ctx context.Context, c *model.Client, id string) (err error) {
253         err = s.authenticateClient(ctx, c)
254         if err != nil {
255                 return
256         }
257         err = checkCSRF(ctx, c)
258         if err != nil {
259                 return
260         }
261         return s.Service.Follow(ctx, c, id)
262 }
263
264 func (s *as) UnFollow(ctx context.Context, c *model.Client, id string) (err error) {
265         err = s.authenticateClient(ctx, c)
266         if err != nil {
267                 return
268         }
269         err = checkCSRF(ctx, c)
270         if err != nil {
271                 return
272         }
273         return s.Service.UnFollow(ctx, c, id)
274 }
275
276 func (s *as) SaveSettings(ctx context.Context, c *model.Client, settings *model.Settings) (err error) {
277         err = s.authenticateClient(ctx, c)
278         if err != nil {
279                 return
280         }
281         err = checkCSRF(ctx, c)
282         if err != nil {
283                 return
284         }
285         return s.Service.SaveSettings(ctx, c, settings)
286 }