909a9a29bc6a86a236d66fa64a8054c5b4c809d6
[bloat] / service / auth.go
1 package service
2
3 import (
4         "context"
5         "errors"
6         "io"
7         "mime/multipart"
8
9         "bloat/model"
10         "mastodon"
11 )
12
13 var (
14         ErrInvalidSession   = errors.New("invalid session")
15         ErrInvalidCSRFToken = errors.New("invalid csrf token")
16 )
17
18 type authService struct {
19         sessionRepo model.SessionRepository
20         appRepo     model.AppRepository
21         Service
22 }
23
24 func NewAuthService(sessionRepo model.SessionRepository, appRepo model.AppRepository, s Service) Service {
25         return &authService{sessionRepo, appRepo, s}
26 }
27
28 func (s *authService) getClient(ctx context.Context) (c *model.Client, err error) {
29         sessionID, ok := ctx.Value("session_id").(string)
30         if !ok || len(sessionID) < 1 {
31                 return nil, ErrInvalidSession
32         }
33         session, err := s.sessionRepo.Get(sessionID)
34         if err != nil {
35                 return nil, ErrInvalidSession
36         }
37         client, err := s.appRepo.Get(session.InstanceDomain)
38         if err != nil {
39                 return
40         }
41         mc := mastodon.NewClient(&mastodon.Config{
42                 Server:       client.InstanceURL,
43                 ClientID:     client.ClientID,
44                 ClientSecret: client.ClientSecret,
45                 AccessToken:  session.AccessToken,
46         })
47         c = &model.Client{Client: mc, Session: session}
48         return c, nil
49 }
50
51 func checkCSRF(ctx context.Context, c *model.Client) (err error) {
52         csrfToken, ok := ctx.Value("csrf_token").(string)
53         if !ok || csrfToken != c.Session.CSRFToken {
54                 return ErrInvalidCSRFToken
55         }
56         return nil
57 }
58
59 func (s *authService) GetAuthUrl(ctx context.Context, instance string) (
60         redirectUrl string, sessionID string, err error) {
61         return s.Service.GetAuthUrl(ctx, instance)
62 }
63
64 func (s *authService) GetUserToken(ctx context.Context, sessionID string, c *model.Client,
65         code string) (token string, err error) {
66         c, err = s.getClient(ctx)
67         if err != nil {
68                 return
69         }
70
71         token, err = s.Service.GetUserToken(ctx, c.Session.ID, c, code)
72         if err != nil {
73                 return
74         }
75
76         c.Session.AccessToken = token
77         err = s.sessionRepo.Add(c.Session)
78         if err != nil {
79                 return
80         }
81
82         return
83 }
84
85 func (s *authService) ServeErrorPage(ctx context.Context, client io.Writer, c *model.Client, err error) {
86         c, _ = s.getClient(ctx)
87         s.Service.ServeErrorPage(ctx, client, c, err)
88 }
89
90 func (s *authService) ServeSigninPage(ctx context.Context, client io.Writer) (err error) {
91         return s.Service.ServeSigninPage(ctx, client)
92 }
93
94 func (s *authService) ServeTimelinePage(ctx context.Context, client io.Writer,
95         c *model.Client, timelineType string, maxID string, sinceID string, minID string) (err error) {
96         c, err = s.getClient(ctx)
97         if err != nil {
98                 return
99         }
100         return s.Service.ServeTimelinePage(ctx, client, c, timelineType, maxID, sinceID, minID)
101 }
102
103 func (s *authService) ServeThreadPage(ctx context.Context, client io.Writer, c *model.Client, id string, reply bool) (err error) {
104         c, err = s.getClient(ctx)
105         if err != nil {
106                 return
107         }
108         return s.Service.ServeThreadPage(ctx, client, c, id, reply)
109 }
110
111 func (s *authService) ServeNotificationPage(ctx context.Context, client io.Writer, c *model.Client, maxID string, minID string) (err error) {
112         c, err = s.getClient(ctx)
113         if err != nil {
114                 return
115         }
116         return s.Service.ServeNotificationPage(ctx, client, c, maxID, minID)
117 }
118
119 func (s *authService) ServeUserPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
120         c, err = s.getClient(ctx)
121         if err != nil {
122                 return
123         }
124         return s.Service.ServeUserPage(ctx, client, c, id, maxID, minID)
125 }
126
127 func (s *authService) ServeAboutPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
128         c, err = s.getClient(ctx)
129         if err != nil {
130                 return
131         }
132         return s.Service.ServeAboutPage(ctx, client, c)
133 }
134
135 func (s *authService) ServeEmojiPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
136         c, err = s.getClient(ctx)
137         if err != nil {
138                 return
139         }
140         return s.Service.ServeEmojiPage(ctx, client, c)
141 }
142
143 func (s *authService) ServeLikedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
144         c, err = s.getClient(ctx)
145         if err != nil {
146                 return
147         }
148         return s.Service.ServeLikedByPage(ctx, client, c, id)
149 }
150
151 func (s *authService) ServeRetweetedByPage(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
152         c, err = s.getClient(ctx)
153         if err != nil {
154                 return
155         }
156         return s.Service.ServeRetweetedByPage(ctx, client, c, id)
157 }
158
159 func (s *authService) ServeFollowingPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
160         c, err = s.getClient(ctx)
161         if err != nil {
162                 return
163         }
164         return s.Service.ServeFollowingPage(ctx, client, c, id, maxID, minID)
165 }
166
167 func (s *authService) ServeFollowersPage(ctx context.Context, client io.Writer, c *model.Client, id string, maxID string, minID string) (err error) {
168         c, err = s.getClient(ctx)
169         if err != nil {
170                 return
171         }
172         return s.Service.ServeFollowersPage(ctx, client, c, id, maxID, minID)
173 }
174
175 func (s *authService) ServeSearchPage(ctx context.Context, client io.Writer, c *model.Client, q string, qType string, offset int) (err error) {
176         c, err = s.getClient(ctx)
177         if err != nil {
178                 return
179         }
180         return s.Service.ServeSearchPage(ctx, client, c, q, qType, offset)
181 }
182
183 func (s *authService) ServeSettingsPage(ctx context.Context, client io.Writer, c *model.Client) (err error) {
184         c, err = s.getClient(ctx)
185         if err != nil {
186                 return
187         }
188         return s.Service.ServeSettingsPage(ctx, client, c)
189 }
190
191 func (s *authService) SaveSettings(ctx context.Context, client io.Writer, c *model.Client, settings *model.Settings) (err error) {
192         c, err = s.getClient(ctx)
193         if err != nil {
194                 return
195         }
196         err = checkCSRF(ctx, c)
197         if err != nil {
198                 return
199         }
200         return s.Service.SaveSettings(ctx, client, c, settings)
201 }
202
203 func (s *authService) Like(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
204         c, err = s.getClient(ctx)
205         if err != nil {
206                 return
207         }
208         err = checkCSRF(ctx, c)
209         if err != nil {
210                 return
211         }
212         return s.Service.Like(ctx, client, c, id)
213 }
214
215 func (s *authService) UnLike(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
216         c, err = s.getClient(ctx)
217         if err != nil {
218                 return
219         }
220         err = checkCSRF(ctx, c)
221         if err != nil {
222                 return
223         }
224         return s.Service.UnLike(ctx, client, c, id)
225 }
226
227 func (s *authService) Retweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
228         c, err = s.getClient(ctx)
229         if err != nil {
230                 return
231         }
232         err = checkCSRF(ctx, c)
233         if err != nil {
234                 return
235         }
236         return s.Service.Retweet(ctx, client, c, id)
237 }
238
239 func (s *authService) UnRetweet(ctx context.Context, client io.Writer, c *model.Client, id string) (count int64, err error) {
240         c, err = s.getClient(ctx)
241         if err != nil {
242                 return
243         }
244         err = checkCSRF(ctx, c)
245         if err != nil {
246                 return
247         }
248         return s.Service.UnRetweet(ctx, client, c, id)
249 }
250
251 func (s *authService) PostTweet(ctx context.Context, client io.Writer, c *model.Client, content string, replyToID string, format string, visibility string, isNSFW bool, files []*multipart.FileHeader) (id string, err error) {
252         c, err = s.getClient(ctx)
253         if err != nil {
254                 return
255         }
256         err = checkCSRF(ctx, c)
257         if err != nil {
258                 return
259         }
260         return s.Service.PostTweet(ctx, client, c, content, replyToID, format, visibility, isNSFW, files)
261 }
262
263 func (s *authService) Follow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
264         c, err = s.getClient(ctx)
265         if err != nil {
266                 return
267         }
268         err = checkCSRF(ctx, c)
269         if err != nil {
270                 return
271         }
272         return s.Service.Follow(ctx, client, c, id)
273 }
274
275 func (s *authService) UnFollow(ctx context.Context, client io.Writer, c *model.Client, id string) (err error) {
276         c, err = s.getClient(ctx)
277         if err != nil {
278                 return
279         }
280         err = checkCSRF(ctx, c)
281         if err != nil {
282                 return
283         }
284         return s.Service.UnFollow(ctx, client, c, id)
285 }