Add CSRF protection
[bloat] / service / transport.go
1 package service
2
3 import (
4         "context"
5         "encoding/json"
6         "io"
7         "mime/multipart"
8         "net/http"
9         "path"
10         "strconv"
11         "time"
12
13         "bloat/model"
14
15         "github.com/gorilla/mux"
16 )
17
// ctx is a package-level background context. Handlers that do not derive a
// per-request context (e.g. the sign-in handlers) pass this one to the Service.
var ctx = context.Background()

// cookieAge is one year expressed in seconds (365 * 24 * 60 * 60), kept as a
// string. Nothing else in this file references it.
var cookieAge = "31536000"
22
23 func NewHandler(s Service, staticDir string) http.Handler {
24         r := mux.NewRouter()
25
26         r.PathPrefix("/static").Handler(http.StripPrefix("/static",
27                 http.FileServer(http.Dir(path.Join(".", staticDir)))))
28
29         r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
30                 location := "/signin"
31
32                 sessionID, _ := req.Cookie("session_id")
33                 if sessionID != nil && len(sessionID.Value) > 0 {
34                         location = "/timeline/home"
35                 }
36
37                 w.Header().Add("Location", location)
38                 w.WriteHeader(http.StatusFound)
39         }).Methods(http.MethodGet)
40
41         r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) {
42                 err := s.ServeSigninPage(ctx, w)
43                 if err != nil {
44                         s.ServeErrorPage(ctx, w, nil, err)
45                         return
46                 }
47         }).Methods(http.MethodGet)
48
49         r.HandleFunc("/signin", func(w http.ResponseWriter, req *http.Request) {
50                 instance := req.FormValue("instance")
51                 url, sessionID, err := s.GetAuthUrl(ctx, instance)
52                 if err != nil {
53                         s.ServeErrorPage(ctx, w, nil, err)
54                         return
55                 }
56
57                 http.SetCookie(w, &http.Cookie{
58                         Name:    "session_id",
59                         Value:   sessionID,
60                         Expires: time.Now().Add(365 * 24 * time.Hour),
61                 })
62
63                 w.Header().Add("Location", url)
64                 w.WriteHeader(http.StatusFound)
65         }).Methods(http.MethodPost)
66
67         r.HandleFunc("/oauth_callback", func(w http.ResponseWriter, req *http.Request) {
68                 ctx := getContextWithSession(context.Background(), req)
69                 token := req.URL.Query().Get("code")
70                 _, err := s.GetUserToken(ctx, "", nil, token)
71                 if err != nil {
72                         s.ServeErrorPage(ctx, w, nil, err)
73                         return
74                 }
75
76                 w.Header().Add("Location", "/timeline/home")
77                 w.WriteHeader(http.StatusFound)
78         }).Methods(http.MethodGet)
79
80         r.HandleFunc("/timeline", func(w http.ResponseWriter, req *http.Request) {
81                 w.Header().Add("Location", "/timeline/home")
82                 w.WriteHeader(http.StatusFound)
83         }).Methods(http.MethodGet)
84
85         r.HandleFunc("/timeline/{type}", func(w http.ResponseWriter, req *http.Request) {
86                 ctx := getContextWithSession(context.Background(), req)
87
88                 timelineType, _ := mux.Vars(req)["type"]
89                 maxID := req.URL.Query().Get("max_id")
90                 sinceID := req.URL.Query().Get("since_id")
91                 minID := req.URL.Query().Get("min_id")
92
93                 err := s.ServeTimelinePage(ctx, w, nil, timelineType, maxID, sinceID, minID)
94                 if err != nil {
95                         s.ServeErrorPage(ctx, w, nil, err)
96                         return
97                 }
98         }).Methods(http.MethodGet)
99
100         r.HandleFunc("/thread/{id}", func(w http.ResponseWriter, req *http.Request) {
101                 ctx := getContextWithSession(context.Background(), req)
102                 id, _ := mux.Vars(req)["id"]
103                 reply := req.URL.Query().Get("reply")
104                 err := s.ServeThreadPage(ctx, w, nil, id, len(reply) > 1)
105                 if err != nil {
106                         s.ServeErrorPage(ctx, w, nil, err)
107                         return
108                 }
109         }).Methods(http.MethodGet)
110
111         r.HandleFunc("/likedby/{id}", func(w http.ResponseWriter, req *http.Request) {
112                 ctx := getContextWithSession(context.Background(), req)
113                 id, _ := mux.Vars(req)["id"]
114
115                 err := s.ServeLikedByPage(ctx, w, nil, id)
116                 if err != nil {
117                         s.ServeErrorPage(ctx, w, nil, err)
118                         return
119                 }
120         }).Methods(http.MethodGet)
121
122         r.HandleFunc("/retweetedby/{id}", func(w http.ResponseWriter, req *http.Request) {
123                 ctx := getContextWithSession(context.Background(), req)
124                 id, _ := mux.Vars(req)["id"]
125
126                 err := s.ServeRetweetedByPage(ctx, w, nil, id)
127                 if err != nil {
128                         s.ServeErrorPage(ctx, w, nil, err)
129                         return
130                 }
131         }).Methods(http.MethodGet)
132
133         r.HandleFunc("/following/{id}", func(w http.ResponseWriter, req *http.Request) {
134                 ctx := getContextWithSession(context.Background(), req)
135
136                 id, _ := mux.Vars(req)["id"]
137                 maxID := req.URL.Query().Get("max_id")
138                 minID := req.URL.Query().Get("min_id")
139
140                 err := s.ServeFollowingPage(ctx, w, nil, id, maxID, minID)
141                 if err != nil {
142                         s.ServeErrorPage(ctx, w, nil, err)
143                         return
144                 }
145         }).Methods(http.MethodGet)
146
147         r.HandleFunc("/followers/{id}", func(w http.ResponseWriter, req *http.Request) {
148                 ctx := getContextWithSession(context.Background(), req)
149
150                 id, _ := mux.Vars(req)["id"]
151                 maxID := req.URL.Query().Get("max_id")
152                 minID := req.URL.Query().Get("min_id")
153
154                 err := s.ServeFollowersPage(ctx, w, nil, id, maxID, minID)
155                 if err != nil {
156                         s.ServeErrorPage(ctx, w, nil, err)
157                         return
158                 }
159         }).Methods(http.MethodGet)
160
161         r.HandleFunc("/like/{id}", func(w http.ResponseWriter, req *http.Request) {
162                 ctx := getContextWithSession(context.Background(), req)
163                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
164
165                 id, _ := mux.Vars(req)["id"]
166                 retweetedByID := req.FormValue("retweeted_by_id")
167
168                 _, err := s.Like(ctx, w, nil, id)
169                 if err != nil {
170                         s.ServeErrorPage(ctx, w, nil, err)
171                         return
172                 }
173
174                 rID := id
175                 if len(retweetedByID) > 0 {
176                         rID = retweetedByID
177                 }
178                 w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID)
179                 w.WriteHeader(http.StatusFound)
180         }).Methods(http.MethodPost)
181
182         r.HandleFunc("/unlike/{id}", func(w http.ResponseWriter, req *http.Request) {
183                 ctx := getContextWithSession(context.Background(), req)
184                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
185
186                 id, _ := mux.Vars(req)["id"]
187                 retweetedByID := req.FormValue("retweeted_by_id")
188
189                 _, err := s.UnLike(ctx, w, nil, id)
190                 if err != nil {
191                         s.ServeErrorPage(ctx, w, nil, err)
192                         return
193                 }
194
195                 rID := id
196                 if len(retweetedByID) > 0 {
197                         rID = retweetedByID
198                 }
199                 w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID)
200                 w.WriteHeader(http.StatusFound)
201         }).Methods(http.MethodPost)
202
203         r.HandleFunc("/retweet/{id}", func(w http.ResponseWriter, req *http.Request) {
204                 ctx := getContextWithSession(context.Background(), req)
205                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
206
207                 id, _ := mux.Vars(req)["id"]
208                 retweetedByID := req.FormValue("retweeted_by_id")
209
210                 _, err := s.Retweet(ctx, w, nil, id)
211                 if err != nil {
212                         s.ServeErrorPage(ctx, w, nil, err)
213                         return
214                 }
215
216                 rID := id
217                 if len(retweetedByID) > 0 {
218                         rID = retweetedByID
219                 }
220                 w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID)
221                 w.WriteHeader(http.StatusFound)
222         }).Methods(http.MethodPost)
223
224         r.HandleFunc("/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) {
225                 ctx := getContextWithSession(context.Background(), req)
226                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
227
228                 id, _ := mux.Vars(req)["id"]
229                 retweetedByID := req.FormValue("retweeted_by_id")
230
231                 _, err := s.UnRetweet(ctx, w, nil, id)
232                 if err != nil {
233                         s.ServeErrorPage(ctx, w, nil, err)
234                         return
235                 }
236
237                 rID := id
238                 if len(retweetedByID) > 0 {
239                         rID = retweetedByID
240                 }
241                 w.Header().Add("Location", req.Header.Get("Referer")+"#status-"+rID)
242                 w.WriteHeader(http.StatusFound)
243         }).Methods(http.MethodPost)
244
245         r.HandleFunc("/fluoride/like/{id}", func(w http.ResponseWriter, req *http.Request) {
246                 ctx := getContextWithSession(context.Background(), req)
247                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
248
249                 id, _ := mux.Vars(req)["id"]
250                 count, err := s.Like(ctx, w, nil, id)
251                 if err != nil {
252                         s.ServeErrorPage(ctx, w, nil, err)
253                         return
254                 }
255
256                 err = serveJson(w, count)
257                 if err != nil {
258                         s.ServeErrorPage(ctx, w, nil, err)
259                         return
260                 }
261         }).Methods(http.MethodPost)
262
263         r.HandleFunc("/fluoride/unlike/{id}", func(w http.ResponseWriter, req *http.Request) {
264                 ctx := getContextWithSession(context.Background(), req)
265                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
266
267                 id, _ := mux.Vars(req)["id"]
268                 count, err := s.UnLike(ctx, w, nil, id)
269                 if err != nil {
270                         s.ServeErrorPage(ctx, w, nil, err)
271                         return
272                 }
273
274                 err = serveJson(w, count)
275                 if err != nil {
276                         s.ServeErrorPage(ctx, w, nil, err)
277                         return
278                 }
279         }).Methods(http.MethodPost)
280
281         r.HandleFunc("/fluoride/retweet/{id}", func(w http.ResponseWriter, req *http.Request) {
282                 ctx := getContextWithSession(context.Background(), req)
283                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
284
285                 id, _ := mux.Vars(req)["id"]
286                 count, err := s.Retweet(ctx, w, nil, id)
287                 if err != nil {
288                         s.ServeErrorPage(ctx, w, nil, err)
289                         return
290                 }
291
292                 err = serveJson(w, count)
293                 if err != nil {
294                         s.ServeErrorPage(ctx, w, nil, err)
295                         return
296                 }
297         }).Methods(http.MethodPost)
298
299         r.HandleFunc("/fluoride/unretweet/{id}", func(w http.ResponseWriter, req *http.Request) {
300                 ctx := getContextWithSession(context.Background(), req)
301                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
302
303                 id, _ := mux.Vars(req)["id"]
304                 count, err := s.UnRetweet(ctx, w, nil, id)
305                 if err != nil {
306                         s.ServeErrorPage(ctx, w, nil, err)
307                         return
308                 }
309
310                 err = serveJson(w, count)
311                 if err != nil {
312                         s.ServeErrorPage(ctx, w, nil, err)
313                         return
314                 }
315         }).Methods(http.MethodPost)
316
317         r.HandleFunc("/post", func(w http.ResponseWriter, req *http.Request) {
318                 err := req.ParseMultipartForm(4 << 20)
319                 if err != nil {
320                         s.ServeErrorPage(ctx, w, nil, err)
321                         return
322                 }
323
324                 ctx := getContextWithSession(context.Background(), req)
325                 ctx = context.WithValue(ctx, "csrf_token",
326                         getMultipartFormValue(req.MultipartForm, "csrf_token"))
327
328                 content := getMultipartFormValue(req.MultipartForm, "content")
329                 replyToID := getMultipartFormValue(req.MultipartForm, "reply_to_id")
330                 format := getMultipartFormValue(req.MultipartForm, "format")
331                 visibility := getMultipartFormValue(req.MultipartForm, "visibility")
332                 isNSFW := "on" == getMultipartFormValue(req.MultipartForm, "is_nsfw")
333
334                 files := req.MultipartForm.File["attachments"]
335
336                 id, err := s.PostTweet(ctx, w, nil, content, replyToID, format, visibility, isNSFW, files)
337                 if err != nil {
338                         s.ServeErrorPage(ctx, w, nil, err)
339                         return
340                 }
341
342                 location := "/timeline/home" + "#status-" + id
343                 if len(replyToID) > 0 {
344                         location = "/thread/" + replyToID + "#status-" + id
345                 }
346                 w.Header().Add("Location", location)
347                 w.WriteHeader(http.StatusFound)
348         }).Methods(http.MethodPost)
349
350         r.HandleFunc("/notifications", func(w http.ResponseWriter, req *http.Request) {
351                 ctx := getContextWithSession(context.Background(), req)
352
353                 maxID := req.URL.Query().Get("max_id")
354                 minID := req.URL.Query().Get("min_id")
355
356                 err := s.ServeNotificationPage(ctx, w, nil, maxID, minID)
357                 if err != nil {
358                         s.ServeErrorPage(ctx, w, nil, err)
359                         return
360                 }
361         }).Methods(http.MethodGet)
362
363         r.HandleFunc("/user/{id}", func(w http.ResponseWriter, req *http.Request) {
364                 ctx := getContextWithSession(context.Background(), req)
365
366                 id, _ := mux.Vars(req)["id"]
367                 maxID := req.URL.Query().Get("max_id")
368                 minID := req.URL.Query().Get("min_id")
369
370                 err := s.ServeUserPage(ctx, w, nil, id, maxID, minID)
371                 if err != nil {
372                         s.ServeErrorPage(ctx, w, nil, err)
373                         return
374                 }
375         }).Methods(http.MethodGet)
376
377         r.HandleFunc("/follow/{id}", func(w http.ResponseWriter, req *http.Request) {
378                 ctx := getContextWithSession(context.Background(), req)
379                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
380
381                 id, _ := mux.Vars(req)["id"]
382
383                 err := s.Follow(ctx, w, nil, id)
384                 if err != nil {
385                         s.ServeErrorPage(ctx, w, nil, err)
386                         return
387                 }
388
389                 w.Header().Add("Location", req.Header.Get("Referer"))
390                 w.WriteHeader(http.StatusFound)
391         }).Methods(http.MethodPost)
392
393         r.HandleFunc("/unfollow/{id}", func(w http.ResponseWriter, req *http.Request) {
394                 ctx := getContextWithSession(context.Background(), req)
395                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
396
397                 id, _ := mux.Vars(req)["id"]
398
399                 err := s.UnFollow(ctx, w, nil, id)
400                 if err != nil {
401                         s.ServeErrorPage(ctx, w, nil, err)
402                         return
403                 }
404
405                 w.Header().Add("Location", req.Header.Get("Referer"))
406                 w.WriteHeader(http.StatusFound)
407         }).Methods(http.MethodPost)
408
409         r.HandleFunc("/about", func(w http.ResponseWriter, req *http.Request) {
410                 ctx := getContextWithSession(context.Background(), req)
411
412                 err := s.ServeAboutPage(ctx, w, nil)
413                 if err != nil {
414                         s.ServeErrorPage(ctx, w, nil, err)
415                         return
416                 }
417         }).Methods(http.MethodGet)
418
419         r.HandleFunc("/emojis", func(w http.ResponseWriter, req *http.Request) {
420                 ctx := getContextWithSession(context.Background(), req)
421
422                 err := s.ServeEmojiPage(ctx, w, nil)
423                 if err != nil {
424                         s.ServeErrorPage(ctx, w, nil, err)
425                         return
426                 }
427         }).Methods(http.MethodGet)
428
429         r.HandleFunc("/search", func(w http.ResponseWriter, req *http.Request) {
430                 ctx := getContextWithSession(context.Background(), req)
431
432                 q := req.URL.Query().Get("q")
433                 qType := req.URL.Query().Get("type")
434                 offsetStr := req.URL.Query().Get("offset")
435
436                 var offset int
437                 var err error
438                 if len(offsetStr) > 1 {
439                         offset, err = strconv.Atoi(offsetStr)
440                         if err != nil {
441                                 s.ServeErrorPage(ctx, w, nil, err)
442                                 return
443                         }
444                 }
445
446                 err = s.ServeSearchPage(ctx, w, nil, q, qType, offset)
447                 if err != nil {
448                         s.ServeErrorPage(ctx, w, nil, err)
449                         return
450                 }
451         }).Methods(http.MethodGet)
452
453         r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) {
454                 ctx := getContextWithSession(context.Background(), req)
455
456                 err := s.ServeSettingsPage(ctx, w, nil)
457                 if err != nil {
458                         s.ServeErrorPage(ctx, w, nil, err)
459                         return
460                 }
461         }).Methods(http.MethodGet)
462
463         r.HandleFunc("/settings", func(w http.ResponseWriter, req *http.Request) {
464                 ctx := getContextWithSession(context.Background(), req)
465                 ctx = context.WithValue(ctx, "csrf_token", req.FormValue("csrf_token"))
466
467                 visibility := req.FormValue("visibility")
468                 copyScope := req.FormValue("copy_scope") == "true"
469                 threadInNewTab := req.FormValue("thread_in_new_tab") == "true"
470                 maskNSFW := req.FormValue("mask_nsfw") == "true"
471                 fluorideMode := req.FormValue("fluoride_mode") == "true"
472                 darkMode := req.FormValue("dark_mode") == "true"
473                 settings := &model.Settings{
474                         DefaultVisibility: visibility,
475                         CopyScope:         copyScope,
476                         ThreadInNewTab:    threadInNewTab,
477                         MaskNSFW:          maskNSFW,
478                         FluorideMode:      fluorideMode,
479                         DarkMode:          darkMode,
480                 }
481
482                 err := s.SaveSettings(ctx, w, nil, settings)
483                 if err != nil {
484                         s.ServeErrorPage(ctx, w, nil, err)
485                         return
486                 }
487
488                 w.Header().Add("Location", req.Header.Get("Referer"))
489                 w.WriteHeader(http.StatusFound)
490         }).Methods(http.MethodPost)
491
492         r.HandleFunc("/signout", func(w http.ResponseWriter, req *http.Request) {
493                 // TODO remove session from database
494                 http.SetCookie(w, &http.Cookie{
495                         Name:    "session_id",
496                         Value:   "",
497                         Expires: time.Now(),
498                 })
499                 w.Header().Add("Location", "/")
500                 w.WriteHeader(http.StatusFound)
501         }).Methods(http.MethodGet)
502
503         return r
504 }
505
// getContextWithSession derives a context from ctx that carries the request's
// "session_id" cookie value under the "session_id" key. If the request has no
// such cookie, ctx itself is returned.
func getContextWithSession(ctx context.Context, req *http.Request) context.Context {
	if cookie, err := req.Cookie("session_id"); err == nil {
		return context.WithValue(ctx, "session_id", cookie.Value)
	}
	return ctx
}
513
// getMultipartFormValue returns the first value submitted for key in a parsed
// multipart form. It returns "" when mf is nil, when key is absent, or when
// key has no values.
func getMultipartFormValue(mf *multipart.Form, key string) (val string) {
	if mf == nil {
		return ""
	}
	// Indexing a nil map or a missing key yields a nil slice, so one length
	// check covers both the "absent" and the "empty" cases.
	if vals := mf.Value[key]; len(vals) > 0 {
		return vals[0]
	}
	return ""
}
524
// serveJson writes data to w as a JSON object of the form {"data": <data>},
// followed by a newline, and returns any encoding or write error.
func serveJson(w io.Writer, data interface{}) (err error) {
	envelope := struct {
		Data interface{} `json:"data"`
	}{Data: data}
	return json.NewEncoder(w).Encode(envelope)
}