// Package subimporter implements a bulk ZIP/CSV importer of subscribers. // It implements a simple queue for buffering imports and committing records // to DB along with ZIP and CSV handling utilities. It is meant to be used as // a singleton as each Importer instance is stateful, where it keeps track of // an import in progress. Only one import should happen on a single importer // instance at a time. package subimporter import ( "archive/zip" "bytes" "database/sql" "encoding/csv" "encoding/json" "errors" "fmt" "io" "log" "net/mail" "os" "regexp" "strings" "sync" "github.com/gofrs/uuid" "github.com/knadh/listmonk/internal/i18n" "github.com/knadh/listmonk/models" "github.com/lib/pq" ) const ( // stdInputMaxLen is the maximum allowed length for a standard input field. stdInputMaxLen = 200 // commitBatchSize is the number of inserts to commit in a single SQL transaction. commitBatchSize = 10000 ) // Various import statuses. const ( StatusNone = "none" StatusImporting = "importing" StatusStopping = "stopping" StatusFinished = "finished" StatusFailed = "failed" ModeSubscribe = "subscribe" ModeBlocklist = "blocklist" ) // Importer represents the bulk CSV subscriber import system. type Importer struct { opt Options db *sql.DB i18n *i18n.I18n domainBlocklist map[string]bool hasBlocklistWildcards bool stop chan bool status Status sync.RWMutex } // Options represents import options. type Options struct { UpsertStmt *sql.Stmt BlocklistStmt *sql.Stmt UpdateListDateStmt *sql.Stmt NotifCB models.AdminNotifCallback // Lookup table for blocklisted domains. DomainBlocklist []string } // Session represents a single import session. type Session struct { im *Importer subQueue chan SubReq log *log.Logger opt SessionOpt } // SessionOpt represents the options for an importer session. type SessionOpt struct { Filename string `json:"filename"` Mode string `json:"mode"` SubStatus string `json:"subscription_status"` Overwrite bool `json:"overwrite"` Delim string `json:"delim"` ListIDs []int `json:"lists"` } // Status represents statistics from an ongoing import session. type Status struct { Name string `json:"name"` Total int `json:"total"` Imported int `json:"imported"` Status string `json:"status"` logBuf *bytes.Buffer } // SubReq is a wrapper over the Subscriber model. type SubReq struct { models.Subscriber Lists []int `json:"lists"` ListUUIDs []string `json:"list_uuids"` PreconfirmSubs bool `json:"preconfirm_subscriptions"` } type importStatusTpl struct { Name string Status string Imported int Total int } var ( // ErrIsImporting is thrown when an import request is made while an // import is already running. ErrIsImporting = errors.New("import is already running") csvHeaders = map[string]bool{ "email": true, "name": true, "attributes": true} regexCleanStr = regexp.MustCompile("[[:^ascii:]]") ) // New returns a new instance of Importer. func New(opt Options, db *sql.DB, i *i18n.I18n) *Importer { im := Importer{ opt: opt, db: db, i18n: i, domainBlocklist: make(map[string]bool, len(opt.DomainBlocklist)), status: Status{Status: StatusNone, logBuf: bytes.NewBuffer(nil)}, stop: make(chan bool, 1), } // Domain blocklist. for _, d := range opt.DomainBlocklist { im.domainBlocklist[d] = true // Domains with *. as the subdomain prefix, strip that // and add the full domain to the blocklist as well. // eg: *.example.com => example.com if strings.Contains(d, "*.") { im.hasBlocklistWildcards = true im.domainBlocklist[strings.TrimPrefix(d, "*.")] = true } } return &im } // NewSession returns an new instance of Session. It takes the name // of the uploaded file, but doesn't do anything with it but retains it for stats. func (im *Importer) NewSession(opt SessionOpt) (*Session, error) { if !im.isDone() { return nil, errors.New("an import is already running") } im.Lock() im.status = Status{Status: StatusImporting, Name: opt.Filename, logBuf: bytes.NewBuffer(nil)} im.Unlock() s := &Session{ im: im, log: log.New(im.status.logBuf, "", log.Ldate|log.Ltime|log.Lshortfile), subQueue: make(chan SubReq, commitBatchSize), opt: opt, } s.log.Printf("processing '%s'", opt.Filename) return s, nil } // GetStats returns the global Stats of the importer. func (im *Importer) GetStats() Status { im.RLock() defer im.RUnlock() return Status{ Name: im.status.Name, Status: im.status.Status, Total: im.status.Total, Imported: im.status.Imported, } } // GetLogs returns the log entries of the last import session. func (im *Importer) GetLogs() []byte { im.RLock() defer im.RUnlock() if im.status.logBuf == nil { return []byte{} } return im.status.logBuf.Bytes() } // setStatus sets the Importer's status. func (im *Importer) setStatus(status string) { im.Lock() im.status.Status = status im.Unlock() } // getStatus get's the Importer's status. func (im *Importer) getStatus() string { im.RLock() status := im.status.Status im.RUnlock() return status } // isDone returns true if the importer is working (importing|stopping). func (im *Importer) isDone() bool { s := true im.RLock() if im.getStatus() == StatusImporting || im.getStatus() == StatusStopping { s = false } im.RUnlock() return s } // incrementImportCount sets the Importer's "imported" counter. func (im *Importer) incrementImportCount(n int) { im.Lock() im.status.Imported += n im.Unlock() } // sendNotif sends admin notifications for import completions. func (im *Importer) sendNotif(status string) error { var ( s = im.GetStats() out = importStatusTpl{ Name: s.Name, Status: status, Imported: s.Imported, Total: s.Total, } subject = fmt.Sprintf("%s: %s import", strings.Title(status), s.Name) ) return im.opt.NotifCB(subject, out) } // Start is a blocking function that selects on a channel queue until all // subscriber entries in the import session are imported. It should be // invoked as a goroutine. func (s *Session) Start() { var ( tx *sql.Tx stmt *sql.Stmt err error total = 0 cur = 0 listIDs = make([]int, len(s.opt.ListIDs)) ) for i, v := range s.opt.ListIDs { listIDs[i] = v } for sub := range s.subQueue { if cur == 0 { // New transaction batch. tx, err = s.im.db.Begin() if err != nil { s.log.Printf("error creating DB transaction: %v", err) continue } if s.opt.Mode == ModeSubscribe { stmt = tx.Stmt(s.im.opt.UpsertStmt) } else { stmt = tx.Stmt(s.im.opt.BlocklistStmt) } } uu, err := uuid.NewV4() if err != nil { s.log.Printf("error generating UUID: %v", err) tx.Rollback() break } if s.opt.Mode == ModeSubscribe { _, err = stmt.Exec(uu, sub.Email, sub.Name, sub.Attribs, pq.Array(listIDs), s.opt.SubStatus, s.opt.Overwrite) } else if s.opt.Mode == ModeBlocklist { _, err = stmt.Exec(uu, sub.Email, sub.Name, sub.Attribs) } if err != nil { s.log.Printf("error executing insert: %v", err) tx.Rollback() break } cur++ total++ // Batch size is met. Commit. if cur%commitBatchSize == 0 { if err := tx.Commit(); err != nil { tx.Rollback() s.log.Printf("error committing to DB: %v", err) } else { s.im.incrementImportCount(cur) s.log.Printf("imported %d", total) } cur = 0 } } // Queue's closed and there's nothing left to commit. if cur == 0 { s.im.setStatus(StatusFinished) s.log.Printf("imported finished") if _, err := s.im.opt.UpdateListDateStmt.Exec(pq.Array(listIDs)); err != nil { s.log.Printf("error updating lists date: %v", err) } s.im.sendNotif(StatusFinished) return } // Queue's closed and there are records left to commit. if err := tx.Commit(); err != nil { tx.Rollback() s.im.setStatus(StatusFailed) s.log.Printf("error committing to DB: %v", err) s.im.sendNotif(StatusFailed) return } s.im.incrementImportCount(cur) s.im.setStatus(StatusFinished) s.log.Printf("imported finished") if _, err := s.im.opt.UpdateListDateStmt.Exec(pq.Array(listIDs)); err != nil { s.log.Printf("error updating lists date: %v", err) } s.im.sendNotif(StatusFinished) } // Stop stops an active import session. func (s *Session) Stop() { close(s.subQueue) } // ExtractZIP takes a ZIP file's path and extracts all .csv files in it to // a temporary directory, and returns the name of the temp directory and the // list of extracted .csv files. func (s *Session) ExtractZIP(srcPath string, maxCSVs int) (string, []string, error) { if s.im.isDone() { return "", nil, ErrIsImporting } failed := true defer func() { if failed { s.im.setStatus(StatusFailed) } }() z, err := zip.OpenReader(srcPath) if err != nil { return "", nil, err } defer z.Close() // Create a temporary directory to extract the files. dir, err := os.MkdirTemp("", "listmonk") if err != nil { s.log.Printf("error creating temporary directory for extracting ZIP: %v", err) return "", nil, err } files := make([]string, 0, len(z.File)) for _, f := range z.File { fName := f.FileInfo().Name() // Skip directories. if f.FileInfo().IsDir() { s.log.Printf("skipping directory '%s'", fName) continue } // Skip files without the .csv extension. if !strings.HasSuffix(strings.ToLower(fName), ".csv") { s.log.Printf("skipping non .csv file '%s'", fName) continue } s.log.Printf("extracting '%s'", fName) src, err := f.Open() if err != nil { s.log.Printf("error opening '%s' from ZIP: '%v'", fName, err) return "", nil, err } defer src.Close() out, err := os.OpenFile(dir+"/"+fName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { s.log.Printf("error creating '%s/%s': '%v'", dir, fName, err) return "", nil, err } defer out.Close() if _, err := io.Copy(out, src); err != nil { s.log.Printf("error extracting to '%s/%s': '%v'", dir, fName, err) return "", nil, err } s.log.Printf("extracted '%s'", fName) files = append(files, fName) if len(files) > maxCSVs { s.log.Printf("won't extract any more files. Maximum is %d", maxCSVs) break } } if len(files) == 0 { s.log.Println("no CSV files found in the ZIP") return "", nil, errors.New("no CSV files found in the ZIP") } failed = false return dir, files, nil } // LoadCSV loads a CSV file and validates and imports the subscriber entries in it. func (s *Session) LoadCSV(srcPath string, delim rune) error { if s.im.isDone() { return ErrIsImporting } // Default status is "failed" in case the function // returns at one of the many possible errors. failed := true defer func() { if failed { s.im.setStatus(StatusFailed) } }() f, err := os.Open(srcPath) if err != nil { return err } // Count the total number of lines in the file. This doesn't distinguish // between "blank" and non "blank" lines, and is only used to derive // the progress percentage for the frontend. numLines, err := countLines(f) if err != nil { s.log.Printf("error counting lines in '%s': '%v'", srcPath, err) return err } if numLines == 0 { return errors.New("empty file") } s.im.Lock() // Exclude the header from count. s.im.status.Total = numLines - 1 s.im.Unlock() // Rewind, now that we've done a linecount on the same handler. _, _ = f.Seek(0, 0) rd := csv.NewReader(f) rd.Comma = delim // Read the header. csvHdr, err := rd.Read() if err != nil { s.log.Printf("error reading header from '%s': '%v'", srcPath, err) return err } hdrKeys := s.mapCSVHeaders(csvHdr, csvHeaders) // email is a required header. if _, ok := hdrKeys["email"]; !ok { s.log.Printf("'email' column not found in '%s'", srcPath) return errors.New("'email' column not found") } var ( lnHdr = len(hdrKeys) i = 0 ) for { i++ // Check for the stop signal. select { case <-s.im.stop: failed = false close(s.subQueue) s.log.Println("stop request received") return nil default: } cols, err := rd.Read() if err == io.EOF { break } else if err != nil { if err, ok := err.(*csv.ParseError); ok && err.Err == csv.ErrFieldCount { s.log.Printf("skipping line %d. %v", i, err) continue } else { s.log.Printf("error reading CSV '%s'", err) return err } } lnCols := len(cols) if lnCols < lnHdr { s.log.Printf("skipping line %d. column count (%d) does not match minimum header count (%d)", i, lnCols, lnHdr) continue } // Iterate the key map and based on the indices mapped earlier, // form a map of key: csv_value, eg: email: user@user.com. row := make(map[string]string, lnCols) for key := range hdrKeys { row[key] = cols[hdrKeys[key]] } sub := SubReq{} sub.Email = row["email"] if v, ok := row["name"]; ok { sub.Name = v } sub, err = s.im.ValidateFields(sub) if err != nil { s.log.Printf("skipping line %d: %s: %v", i, sub.Email, err) continue } // JSON attributes. if len(row["attributes"]) > 0 { var ( attribs models.JSON b = []byte(row["attributes"]) ) if err := json.Unmarshal(b, &attribs); err != nil { s.log.Printf("skipping invalid attributes JSON on line %d for '%s': %v", i, sub.Email, err) } else { sub.Attribs = attribs } } // Send the subscriber to the queue. s.subQueue <- sub } close(s.subQueue) failed = false return nil } // Stop sends a signal to stop the existing import. func (im *Importer) Stop() { if im.getStatus() != StatusImporting { im.Lock() im.status = Status{Status: StatusNone} im.Unlock() return } select { case im.stop <- true: im.setStatus(StatusStopping) default: } } // SanitizeEmail validates and sanitizes an e-mail string and returns the lowercased, // e-mail component of an e-mail string. func (im *Importer) SanitizeEmail(email string) (string, error) { email = strings.ToLower(strings.TrimSpace(email)) // Since `mail.ParseAddress` parses an email address which can also contain optional name component // here we check if incoming email string is same as the parsed email.Address. So this eliminates // any valid email address with name and also valid address with empty name like ``. em, err := mail.ParseAddress(email) if err != nil || em.Address != email { return "", errors.New(im.i18n.T("subscribers.invalidEmail")) } // Check if the e-mail's domain is blocklisted. The e-mail domain and blocklist config // are always lowercase. d := strings.Split(em.Address, "@") if len(d) == 2 { domain := d[1] // Check the domain as-is. if _, ok := im.domainBlocklist[domain]; ok { return "", errors.New(im.i18n.T("subscribers.domainBlocklisted")) } // If there are wildcards in the blocklist and the email domain has a subdomain, check that. if im.hasBlocklistWildcards && strings.Count(domain, ".") > 1 { parts := strings.Split(domain, ".") // Replace the first part of the subdomain with * and check if that exists in the blocklist. // Eg: test.mail.example.com => *.mail.example.com parts[0] = "*" domain = strings.Join(parts, ".") if _, ok := im.domainBlocklist[domain]; ok { return "", errors.New(im.i18n.T("subscribers.domainBlocklisted")) } } } return em.Address, nil } // ValidateFields validates incoming subscriber field values and returns sanitized fields. func (im *Importer) ValidateFields(s SubReq) (SubReq, error) { if len(s.Email) > 1000 { return s, errors.New(im.i18n.T("subscribers.invalidEmail")) } em, err := im.SanitizeEmail(s.Email) if err != nil { return s, err } s.Email = strings.ToLower(em) // If there's no name, use the name part of the e-mail. s.Name = strings.TrimSpace(s.Name) if len(s.Name) == 0 { name := strings.ToLower(strings.Split(s.Email, "@")[0]) parts := strings.Fields(strings.ReplaceAll(name, ".", " ")) for n, p := range parts { parts[n] = strings.Title(p) } s.Name = strings.Join(parts, " ") } return s, nil } // mapCSVHeaders takes a list of headers obtained from a CSV file, a map of known headers, // and returns a new map with each of the headers in the known map mapped by the position (0-n) // in the given CSV list. func (s *Session) mapCSVHeaders(csvHdrs []string, knownHdrs map[string]bool) map[string]int { // Map 0-n column index to the header keys, name: 0, email: 1 etc. // This is to allow dynamic ordering of columns in th CSV. hdrKeys := make(map[string]int) for i, h := range csvHdrs { // Clean the string of non-ASCII characters (BOM etc.). h := regexCleanStr.ReplaceAllString(h, "") if _, ok := knownHdrs[h]; !ok { s.log.Printf("ignoring unknown header '%s'", h) continue } hdrKeys[h] = i } return hdrKeys } // countLines counts the number of line breaks in a file. This does not // distinguish between "blank" and non "blank" lines. // Credit: https://stackoverflow.com/a/24563853 func countLines(r io.Reader) (int, error) { var ( buf = make([]byte, 32*1024) count = 0 lineSep = []byte{'\n'} ) for { c, err := r.Read(buf) count += bytes.Count(buf[:c], lineSep) switch { case err == io.EOF: return count, nil case err != nil: return count, err } } }