diff --git a/.gitignore b/.gitignore index 7332997a8..b3c96ff69 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,5 @@ _testmain.go # Custom dev scripts win_dev_* +go.work +go.work.sum diff --git a/README.md b/README.md index c57104c10..3ee3ae070 100644 --- a/README.md +++ b/README.md @@ -88,12 +88,12 @@ Portmaster is a privacy suite for your desktop OS. All details and guides live in the dedicated [wiki](https://wiki.safing.io/) -- [Getting Started](https://wiki.safing.io/en/Portmaster/App/GettingStarted) +- [Getting Started](https://wiki.safing.io/en/Portmaster/App) - Install - [on Windows](https://wiki.safing.io/en/Portmaster/Install/Windows) - [on Linux](https://wiki.safing.io/en/Portmaster/Install/Linux) - [Contribute](https://wiki.safing.io/en/Contribute) - [VPN Compatibility](https://wiki.safing.io/en/Portmaster/App/Compatibility#vpn-compatibly) - [Software Compatibility](https://wiki.safing.io/en/Portmaster/App/Compatibility) -- [Architecture](https://wiki.safing.io/en/Portmaster/Architecture/Overview) +- [Architecture](https://wiki.safing.io/en/Portmaster/Architecture) diff --git a/compat/selfcheck.go b/compat/selfcheck.go index c1508d121..4515d93c5 100644 --- a/compat/selfcheck.go +++ b/compat/selfcheck.go @@ -28,12 +28,12 @@ var ( systemIntegrationCheckDialNet = fmt.Sprintf("ip4:%d", uint8(SystemIntegrationCheckProtocol)) systemIntegrationCheckDialIP = SystemIntegrationCheckDstIP.String() systemIntegrationCheckPackets = make(chan packet.Packet, 1) - systemIntegrationCheckWaitDuration = 20 * time.Second + systemIntegrationCheckWaitDuration = 40 * time.Second // DNSCheckInternalDomainScope is the domain scope to use for dns checks. DNSCheckInternalDomainScope = ".self-check." + resolver.InternalSpecialUseDomain dnsCheckReceivedDomain = make(chan string, 1) - dnsCheckWaitDuration = 20 * time.Second + dnsCheckWaitDuration = 40 * time.Second dnsCheckAnswerLock sync.Mutex dnsCheckAnswer net.IP ) diff --git a/firewall/config.go b/firewall/config.go index eaf3fa442..ea1785b3b 100644 --- a/firewall/config.go +++ b/firewall/config.go @@ -23,11 +23,11 @@ var ( askTimeout config.IntOption CfgOptionPermanentVerdictsKey = "filter/permanentVerdicts" - cfgOptionPermanentVerdictsOrder = 96 + cfgOptionPermanentVerdictsOrder = 80 permanentVerdicts config.BoolOption CfgOptionDNSQueryInterceptionKey = "filter/dnsQueryInterception" - cfgOptionDNSQueryInterceptionOrder = 97 + cfgOptionDNSQueryInterceptionOrder = 81 dnsQueryInterception config.BoolOption ) diff --git a/firewall/interception/ebpf/bandwidth/interface.go b/firewall/interception/ebpf/bandwidth/interface.go index f23d44522..f247b157f 100644 --- a/firewall/interception/ebpf/bandwidth/interface.go +++ b/firewall/interception/ebpf/bandwidth/interface.go @@ -133,15 +133,18 @@ func reportBandwidth(ctx context.Context, objs bpfObjects, bandwidthUpdates chan false, ) update := &packet.BandwidthUpdate{ - ConnID: connID, - RecvBytes: skInfo.Rx, - SentBytes: skInfo.Tx, - Method: packet.Absolute, + ConnID: connID, + BytesReceived: skInfo.Rx, + BytesSent: skInfo.Tx, + Method: packet.Absolute, } select { case bandwidthUpdates <- update: case <-ctx.Done(): return + default: + log.Warning("ebpf: bandwidth update queue is full, skipping rest of batch") + return } } } diff --git a/firewall/interception/ebpf/connection_listener/worker.go b/firewall/interception/ebpf/connection_listener/worker.go index d8aced122..1dee07be4 100644 --- a/firewall/interception/ebpf/connection_listener/worker.go +++ b/firewall/interception/ebpf/connection_listener/worker.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "sync/atomic" + "time" "github.com/cilium/ebpf/link" "github.com/cilium/ebpf/ringbuf" @@ -112,9 +113,11 @@ func ConnectionListenerWorker(ctx context.Context, packets chan packet.Packet) e Src: convertArrayToIPv4(event.Saddr, packet.IPVersion(event.IpVersion)), Dst: convertArrayToIPv4(event.Daddr, packet.IPVersion(event.IpVersion)), PID: int(event.Pid), + SeenAt: time.Now(), }) if isEventValid(event) { - log.Debugf("ebpf: received valid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt) + // DEBUG: + // log.Debugf("ebpf: received valid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt) packets <- pkt } else { log.Warningf("ebpf: received invalid connect event: PID: %d Conn: %s", pkt.Info().PID, pkt) diff --git a/firewall/interception/nfq/nfq.go b/firewall/interception/nfq/nfq.go index 585ba96e4..184e15f94 100644 --- a/firewall/interception/nfq/nfq.go +++ b/firewall/interception/nfq/nfq.go @@ -196,7 +196,8 @@ func (q *Queue) packetHandler(ctx context.Context) func(nfqueue.Attribute) int { select { case q.packets <- pkt: - log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt)) + // DEBUG: + // log.Tracef("nfqueue: queued packet %s (%s -> %s) after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, time.Since(pkt.Info().SeenAt)) case <-ctx.Done(): return 0 case <-time.After(time.Second): diff --git a/firewall/interception/nfq/packet.go b/firewall/interception/nfq/packet.go index 6dd421861..8baeff5be 100644 --- a/firewall/interception/nfq/packet.go +++ b/firewall/interception/nfq/packet.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "sync/atomic" - "time" "github.com/florianl/go-nfqueue" "github.com/tevino/abool" @@ -117,7 +116,13 @@ func (pkt *packet) setMark(mark int) error { } break } - log.Tracer(pkt.Ctx()).Tracef("nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, markToString(mark), time.Since(pkt.Info().SeenAt)) + + // DEBUG: + // log.Tracer(pkt.Ctx()).Tracef( + // "nfqueue: marking packet %s (%s -> %s) on queue %d with %s after %s", + // pkt.ID(), pkt.Info().Src, pkt.Info().Dst, pkt.queue.id, + // markToString(mark), time.Since(pkt.Info().SeenAt), + // ) return nil } diff --git a/firewall/interception/windowskext/bandwidth_stats.go b/firewall/interception/windowskext/bandwidth_stats.go index 6e9dd05f7..2a1bddc0f 100644 --- a/firewall/interception/windowskext/bandwidth_stats.go +++ b/firewall/interception/windowskext/bandwidth_stats.go @@ -55,7 +55,7 @@ func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.Bandwidt } // Report all statistics. - for _, stat := range stats { + for i, stat := range stats { connID := packet.CreateConnectionID( packet.IPProtocol(stat.protocol), convertArrayToIP(stat.localIP, stat.ipV6 == 1), stat.localPort, @@ -63,15 +63,18 @@ func reportBandwidth(ctx context.Context, bandwidthUpdates chan *packet.Bandwidt false, ) update := &packet.BandwidthUpdate{ - ConnID: connID, - RecvBytes: stat.receivedBytes, - SentBytes: stat.transmittedBytes, - Method: packet.Additive, + ConnID: connID, + BytesReceived: stat.receivedBytes, + BytesSent: stat.transmittedBytes, + Method: packet.Additive, } select { case bandwidthUpdates <- update: case <-ctx.Done(): return nil + default: + log.Warningf("kext: bandwidth update queue is full, skipping rest of batch (%d entries)", len(stats)-i) + return nil } } diff --git a/firewall/module.go b/firewall/module.go index 345316c00..dd4dcbaac 100644 --- a/firewall/module.go +++ b/firewall/module.go @@ -14,7 +14,7 @@ import ( var module *modules.Module func init() { - module = modules.Register("filter", prep, start, stop, "core", "interception", "intel") + module = modules.Register("filter", prep, start, stop, "core", "interception", "intel", "netquery") subsystems.Register( "filter", "Privacy Filter", diff --git a/firewall/packet_handler.go b/firewall/packet_handler.go index 4fc783ba8..97df6eff7 100644 --- a/firewall/packet_handler.go +++ b/firewall/packet_handler.go @@ -18,6 +18,7 @@ import ( "github.com/safing/portmaster/firewall/inspection" "github.com/safing/portmaster/firewall/interception" "github.com/safing/portmaster/netenv" + "github.com/safing/portmaster/netquery" "github.com/safing/portmaster/network" "github.com/safing/portmaster/network/netutils" "github.com/safing/portmaster/network/packet" @@ -510,7 +511,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V atomic.AddUint64(packetsFailed, 1) err = pkt.Drop() case network.VerdictUndecided, network.VerdictUndeterminable: - log.Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt) + log.Tracer(pkt.Ctx()).Warningf("filter: tried to apply verdict %s to pkt %s: dropping instead", verdict, pkt) fallthrough default: atomic.AddUint64(packetsDropped, 1) @@ -518,7 +519,7 @@ func issueVerdict(conn *network.Connection, pkt packet.Packet, verdict network.V } if err != nil { - log.Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err) + log.Tracer(pkt.Ctx()).Warningf("filter: failed to apply verdict to pkt %s: %s", pkt, err) } } @@ -616,7 +617,7 @@ func bandwidthUpdateHandler(ctx context.Context) error { return nil case bwUpdate := <-interception.BandwidthUpdates: if bwUpdate != nil { - updateBandwidth(bwUpdate) + updateBandwidth(ctx, bwUpdate) // DEBUG: // log.Debugf("filter: bandwidth update: %s", bwUpdate) } else { @@ -626,9 +627,9 @@ func bandwidthUpdateHandler(ctx context.Context) error { } } -func updateBandwidth(bwUpdate *packet.BandwidthUpdate) { +func updateBandwidth(ctx context.Context, bwUpdate *packet.BandwidthUpdate) { // Check if update makes sense. - if bwUpdate.RecvBytes == 0 && bwUpdate.SentBytes == 0 { + if bwUpdate.BytesReceived == 0 && bwUpdate.BytesSent == 0 { return } @@ -648,16 +649,29 @@ func updateBandwidth(bwUpdate *packet.BandwidthUpdate) { // Update stats according to method. switch bwUpdate.Method { case packet.Absolute: - conn.RecvBytes = bwUpdate.RecvBytes - conn.SentBytes = bwUpdate.SentBytes + conn.BytesReceived = bwUpdate.BytesReceived + conn.BytesSent = bwUpdate.BytesSent case packet.Additive: - conn.RecvBytes += bwUpdate.RecvBytes - conn.SentBytes += bwUpdate.SentBytes + conn.BytesReceived += bwUpdate.BytesReceived + conn.BytesSent += bwUpdate.BytesSent default: log.Warningf("filter: unsupported bandwidth update method: %d", bwUpdate.Method) + return } - // TODO: Send update. + // Update bandwidth in the netquery module. + if netquery.DefaultModule != nil && conn.BandwidthEnabled { + if err := netquery.DefaultModule.Store.UpdateBandwidth( + ctx, + conn.HistoryEnabled, + conn.Process().GetID(), + conn.ID, + conn.BytesReceived, + conn.BytesSent, + ); err != nil { + log.Errorf("filter: failed to persist bandwidth data: %s", err) + } + } } func statLogger(ctx context.Context) error { diff --git a/firewall/prompt.go b/firewall/prompt.go index e1a380d90..e3582ba02 100644 --- a/firewall/prompt.go +++ b/firewall/prompt.go @@ -91,12 +91,12 @@ func createPrompt(ctx context.Context, conn *network.Connection) (n *notificatio layeredProfile := conn.Process().Profile() if layeredProfile == nil { log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without profile") - return + return nil } localProfile := layeredProfile.LocalProfile() if localProfile == nil { log.Tracer(ctx).Warningf("filter: tried creating prompt for connection without local profile") - return + return nil } // first check if there is an existing notification for this. diff --git a/go.mod b/go.mod index 69483277c..a046cc5db 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/safing/jess v0.3.1 github.com/safing/portbase v0.17.0 github.com/safing/portmaster-android/go v0.0.0-20230605085256-6abf4c495626 - github.com/safing/spn v0.6.8 + github.com/safing/spn v0.6.9 github.com/shirou/gopsutil v3.21.11+incompatible github.com/spf13/cobra v1.7.0 github.com/spkg/zipfs v0.7.1 @@ -28,6 +28,7 @@ require ( github.com/tannerryan/ring v1.1.2 github.com/tevino/abool v1.2.0 github.com/umahmood/haversine v0.0.0-20151105152445-808ab04add26 + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 golang.org/x/net v0.12.0 golang.org/x/sync v0.3.0 golang.org/x/sys v0.10.0 @@ -86,7 +87,6 @@ require ( github.com/zalando/go-keyring v0.2.3 // indirect go.etcd.io/bbolt v1.3.7 // indirect golang.org/x/crypto v0.11.0 // indirect - golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.11.0 // indirect diff --git a/go.sum b/go.sum index 61148a1b4..ef2bba870 100644 --- a/go.sum +++ b/go.sum @@ -210,8 +210,8 @@ github.com/safing/portbase v0.17.0 h1:RsDzbCGxRIbgaArri3y7MZskfxytEvvkzJpiboDUER github.com/safing/portbase v0.17.0/go.mod h1:eKCRqsfMFLVhNpd2sY/fKvnbuk+LrIYnQEZCg1i86Ho= github.com/safing/portmaster-android/go v0.0.0-20230605085256-6abf4c495626 h1:olc/REnUdpJN/Gmz8B030OxLpMYxyPDTrDILNEw0eKs= github.com/safing/portmaster-android/go v0.0.0-20230605085256-6abf4c495626/go.mod h1:abwyAQrZGemWbSh/aCD9nnkp0SvFFf/mGWkAbOwPnFE= -github.com/safing/spn v0.6.8 h1:2obvyMzyw5X3CIYedLBE88kNBBrJumF84q1qtQSFqkc= -github.com/safing/spn v0.6.8/go.mod h1:Mh9bmkqFhO/dHNi9RWXzoXjQij893I4Lj8Wn4tQ0KZA= +github.com/safing/spn v0.6.9 h1:CCRN5jgshJrLBHwGHl0ywWwhukc+Wff7/I66qgYyymg= +github.com/safing/spn v0.6.9/go.mod h1:Mh9bmkqFhO/dHNi9RWXzoXjQij893I4Lj8Wn4tQ0KZA= github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/seehuhn/fortuna v1.0.1 h1:lu9+CHsmR0bZnx5Ay646XvCSRJ8PJTi5UYJwDBX68H0= diff --git a/netquery/database.go b/netquery/database.go index 0434d3c17..f1abc6333 100644 --- a/netquery/database.go +++ b/netquery/database.go @@ -2,18 +2,23 @@ package netquery import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" + "path" "sort" "strings" "sync" "time" + "github.com/hashicorp/go-multierror" "github.com/jackc/puddle/v2" "zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite/sqlitex" + "github.com/safing/portbase/dataroot" "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" "github.com/safing/portmaster/network" @@ -22,7 +27,7 @@ import ( ) // InMemory is the "file path" to open a new in-memory database. -const InMemory = "file:inmem.db" +const InMemory = "file:inmem.db?mode=memory" // Available connection types as their string representation. const ( @@ -46,6 +51,7 @@ type ( Schema *orm.TableSchema readConnPool *puddle.Pool[*sqlite.Conn] + historyPath string l sync.Mutex writeConn *sqlite.Conn @@ -82,7 +88,9 @@ type ( Latitude float64 `sqlite:"latitude"` Longitude float64 `sqlite:"longitude"` Scope netutils.IPScope `sqlite:"scope"` - Verdict network.Verdict `sqlite:"verdict"` + WorstVerdict network.Verdict `sqlite:"worst_verdict"` + ActiveVerdict network.Verdict `sqlite:"verdict"` + FirewallVerdict network.Verdict `sqlite:"firewall_verdict"` Started time.Time `sqlite:"started,text,time"` Ended *time.Time `sqlite:"ended,text,time"` Tunneled bool `sqlite:"tunneled"` @@ -93,6 +101,8 @@ type ( Allowed *bool `sqlite:"allowed"` ProfileRevision int `sqlite:"profile_revision"` ExitNode *string `sqlite:"exit_node"` + BytesReceived uint64 `sqlite:"bytes_received,default=0"` + BytesSent uint64 `sqlite:"bytes_sent,default=0"` // TODO(ppacher): support "NOT" in search query to get rid of the following helper fields Active bool `sqlite:"active"` // could use "ended IS NOT NULL" or "ended IS NULL" @@ -102,24 +112,33 @@ type ( } ) -// New opens a new in-memory database named path. +// New opens a new in-memory database named path and attaches a persistent history database. // // The returned Database used connection pooling for read-only connections // (see Execute). To perform database writes use either Save() or ExecuteWrite(). // Note that write connections are serialized by the Database object before being // handed over to SQLite. -func New(path string) (*Database, error) { +func New(dbPath string) (*Database, error) { + historyParentDir := dataroot.Root().ChildDir("databases", 0o700) + if err := historyParentDir.Ensure(); err != nil { + return nil, fmt.Errorf("failed to ensure database directory exists: %w", err) + } + + historyPath := "file://" + path.Join(historyParentDir.Path, "history.db") + constructor := func(ctx context.Context) (*sqlite.Conn, error) { c, err := sqlite.OpenConn( - path, + dbPath, sqlite.OpenReadOnly, - sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit. sqlite.OpenSharedCache, - sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { - return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", path, err) + return nil, fmt.Errorf("failed to open read-only sqlite connection at %s: %w", dbPath, err) + } + + if err := sqlitex.ExecuteTransient(c, "ATTACH DATABASE '"+historyPath+"?mode=ro' AS history", nil); err != nil { + return nil, fmt.Errorf("failed to attach history database: %w", err) } return c, nil @@ -146,23 +165,22 @@ func New(path string) (*Database, error) { } writeConn, err := sqlite.OpenConn( - path, + dbPath, sqlite.OpenCreate, sqlite.OpenReadWrite, - sqlite.OpenNoMutex, //nolint:staticcheck // We like to be explicit. sqlite.OpenWAL, sqlite.OpenSharedCache, - sqlite.OpenMemory, sqlite.OpenURI, ) if err != nil { - return nil, fmt.Errorf("failed to open sqlite at %s: %w", path, err) + return nil, fmt.Errorf("failed to open sqlite at %s: %w", dbPath, err) } return &Database{ readConnPool: pool, Schema: schema, writeConn: writeConn, + historyPath: historyPath, }, nil } @@ -189,28 +207,42 @@ func NewInMemory() (*Database, error) { // any data-migrations. Once the history module is implemented this should // become/use a full migration system -- use zombiezen.com/go/sqlite/sqlitemigration. func (db *Database) ApplyMigrations() error { - // get the create-table SQL statement from the inferred schema - sql := db.Schema.CreateStatement(true) - + log.Errorf("applying migrations ...") db.l.Lock() defer db.l.Unlock() - // execute the SQL - if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil { - return fmt.Errorf("failed to create schema: %w", err) + if err := sqlitex.ExecuteTransient(db.writeConn, "ATTACH DATABASE '"+db.historyPath+"?mode=rwc' AS 'history';", nil); err != nil { + return fmt.Errorf("failed to attach history database: %w", err) } - // create a few indexes - indexes := []string{ - `CREATE INDEX profile_id_index ON %s (profile)`, - `CREATE INDEX started_time_index ON %s (strftime('%%s', started)+0)`, - `CREATE INDEX started_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`, - } - for _, idx := range indexes { - stmt := fmt.Sprintf(idx, db.Schema.Name) + dbNames := []string{"main", "history"} + for _, dbName := range dbNames { + // get the create-table SQL statement from the inferred schema + sql := db.Schema.CreateStatement(dbName, true) + log.Debugf("creating table schema for database %q", dbName) - if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil { - return fmt.Errorf("failed to create index: %q: %w", idx, err) + // execute the SQL + if err := sqlitex.ExecuteTransient(db.writeConn, sql, nil); err != nil { + return fmt.Errorf("failed to create schema on database %q: %w", dbName, err) + } + + // create a few indexes + indexes := []string{ + `CREATE INDEX IF NOT EXISTS %sprofile_id_index ON %s (profile)`, + `CREATE INDEX IF NOT EXISTS %sstarted_time_index ON %s (strftime('%%s', started)+0)`, + `CREATE INDEX IF NOT EXISTS %sstarted_ended_time_index ON %s (strftime('%%s', started)+0, strftime('%%s', ended)+0) WHERE ended IS NOT NULL`, + } + for _, idx := range indexes { + name := "" + if dbName != "" { + name = dbName + "." + } + + stmt := fmt.Sprintf(idx, name, db.Schema.Name) + + if err := sqlitex.ExecuteTransient(db.writeConn, stmt, nil); err != nil { + return fmt.Errorf("failed to create index on database %q: %q: %w", dbName, idx, err) + } } } @@ -254,7 +286,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { Count int `sqlite:"count"` } - if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM connections", orm.WithResult(&result)); err != nil { + if err := db.Execute(ctx, "SELECT COUNT(*) AS count FROM (SELECT * FROM main.connections UNION SELECT * from history.connections)", orm.WithResult(&result)); err != nil { return 0, fmt.Errorf("failed to perform query: %w", err) } @@ -265,7 +297,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { return result[0].Count, nil } -// Cleanup removes all connections that have ended before threshold. +// Cleanup removes all connections that have ended before threshold from the live database. // // NOTE(ppacher): there is no easy way to get the number of removed // rows other than counting them in a first step. Though, that's @@ -273,7 +305,7 @@ func (db *Database) CountRows(ctx context.Context) (int, error) { func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, error) { where := `WHERE ended IS NOT NULL AND datetime(ended) < datetime(:threshold)` - sql := "DELETE FROM connections " + where + ";" + sql := "DELETE FROM main.connections " + where + ";" args := orm.WithNamedArgs(map[string]interface{}{ ":threshold": threshold.UTC().Format(orm.SqliteTimeFormat), @@ -303,6 +335,21 @@ func (db *Database) Cleanup(ctx context.Context, threshold time.Time) (int, erro return result[0].Count, nil } +// RemoveAllHistoryData removes all connections from the history database. +func (db *Database) RemoveAllHistoryData(ctx context.Context) error { + query := fmt.Sprintf("DELETE FROM %s.connections", HistoryDatabase) + return db.ExecuteWrite(ctx, query) +} + +// RemoveHistoryForProfile removes all connections from the history database +// for a given profile ID (source/id). +func (db *Database) RemoveHistoryForProfile(ctx context.Context, profileID string) error { + query := fmt.Sprintf("DELETE FROM %s.connections WHERE profile = :profile", HistoryDatabase) + return db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{ + ":profile": profileID, + })) +} + // dumpTo is a simple helper method that dumps all rows stored in the SQLite database // as JSON to w. // Any error aborts dumping rows and is returned. @@ -330,13 +377,76 @@ func (db *Database) dumpTo(ctx context.Context, w io.Writer) error { //nolint:un return enc.Encode(conns) } +// MarkAllHistoryConnectionsEnded marks all connections in the history database as ended. +func (db *Database) MarkAllHistoryConnectionsEnded(ctx context.Context) error { + query := fmt.Sprintf("UPDATE %s.connections SET active = FALSE, ended = :ended WHERE active = TRUE", HistoryDatabase) + + if err := db.ExecuteWrite(ctx, query, orm.WithNamedArgs(map[string]any{ + ":ended": time.Now().Format(orm.SqliteTimeFormat), + })); err != nil { + return err + } + + return nil +} + +// UpdateBandwidth updates bandwidth data for the connection and optionally also writes +// the bandwidth data to the history database. +func (db *Database) UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error { + data := connID + "-" + processKey + hash := sha256.Sum256([]byte(data)) + dbConnID := hex.EncodeToString(hash[:]) + + params := map[string]any{ + ":id": dbConnID, + } + + parts := []string{} + if bytesReceived != 0 { + parts = append(parts, "bytes_received = :bytes_received") + params[":bytes_received"] = bytesReceived + } + + if bytesSent != 0 { + parts = append(parts, "bytes_sent = :bytes_sent") + params[":bytes_sent"] = bytesSent + } + + updateSet := strings.Join(parts, ", ") + + updateStmts := []string{ + fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, LiveDatabase, updateSet), + } + + if enableHistory { + updateStmts = append(updateStmts, + fmt.Sprintf(`UPDATE %s.connections SET %s WHERE id = :id`, HistoryDatabase, updateSet), + ) + } + + merr := new(multierror.Error) + for _, stmt := range updateStmts { + if err := db.ExecuteWrite(ctx, stmt, orm.WithNamedArgs(params)); err != nil { + merr.Errors = append(merr.Errors, err) + } + } + + return merr.ErrorOrNil() +} + // Save inserts the connection conn into the SQLite database. If conn // already exists the table row is updated instead. // // Save uses the database write connection instead of relying on the // connection pool. -func (db *Database) Save(ctx context.Context, conn Conn) error { - connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig) +func (db *Database) Save(ctx context.Context, conn Conn, enableHistory bool) error { + // convert the connection to a param map where each key is already translated + // to the sql column name. We also skip bytes_received and bytes_sent since those + // will be updated independently from the connection object. + connMap, err := orm.ToParamMap(ctx, conn, "", orm.DefaultEncodeConfig, []string{ + "bytes_received", + "bytes_sent", + }) if err != nil { return fmt.Errorf("failed to encode connection for SQL: %w", err) } @@ -367,26 +477,35 @@ func (db *Database) Save(ctx context.Context, conn Conn) error { // TODO(ppacher): make sure this one can be cached to speed up inserting // and save some CPU cycles for the user - sql := fmt.Sprintf( - `INSERT INTO connections (%s) - VALUES(%s) - ON CONFLICT(id) DO UPDATE SET - %s - `, - strings.Join(columns, ", "), - strings.Join(placeholders, ", "), - strings.Join(updateSets, ", "), - ) + dbNames := []DatabaseName{LiveDatabase} - if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{ - Named: values, - ResultFunc: func(stmt *sqlite.Stmt) error { - log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) - return nil - }, - }); err != nil { - log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) - return err + if enableHistory { + dbNames = append(dbNames, HistoryDatabase) + } + + for _, dbName := range dbNames { + sql := fmt.Sprintf( + `INSERT INTO %s.connections (%s) + VALUES(%s) + ON CONFLICT(id) DO UPDATE SET + %s + `, + dbName, + strings.Join(columns, ", "), + strings.Join(placeholders, ", "), + strings.Join(updateSets, ", "), + ) + + if err := sqlitex.Execute(db.writeConn, sql, &sqlitex.ExecOptions{ + Named: values, + ResultFunc: func(stmt *sqlite.Stmt) error { + log.Errorf("netquery: got result statement with %d columns", stmt.ColumnCount()) + return nil + }, + }); err != nil { + log.Errorf("netquery: failed to execute:\n\t%q\n\treturned error was: %s\n\tparameters: %+v", sql, err, values) + return err + } } return nil diff --git a/netquery/manager.go b/netquery/manager.go index 6599d6197..c49aa5c2e 100644 --- a/netquery/manager.go +++ b/netquery/manager.go @@ -25,7 +25,22 @@ type ( // insert or an update. // The ID of Conn is unique and can be trusted to never collide with other // connections of the save device. - Save(context.Context, Conn) error + Save(context.Context, Conn, bool) error + + // MarkAllHistoryConnectionsEnded marks all active connections in the history + // database as ended NOW. + MarkAllHistoryConnectionsEnded(context.Context) error + + // RemoveAllHistoryData removes all connections from the history database. + RemoveAllHistoryData(context.Context) error + + // RemoveHistoryForProfile removes all connections from the history database. + // for a given profile ID (source/id) + RemoveHistoryForProfile(context.Context, string) error + + // UpdateBandwidth updates bandwidth data for the connection and optionally also writes + // the bandwidth data to the history database. + UpdateBandwidth(ctx context.Context, enableHistory bool, processKey string, connID string, bytesReceived uint64, bytesSent uint64) error } // Manager handles new and updated network.Connections feeds and persists them @@ -98,9 +113,10 @@ func (mng *Manager) HandleFeed(ctx context.Context, feed <-chan *network.Connect continue } - log.Tracef("netquery: updating connection %s", conn.ID) + // DEBUG: + // log.Tracef("netquery: updating connection %s", conn.ID) - if err := mng.store.Save(ctx, *model); err != nil { + if err := mng.store.Save(ctx, *model, conn.HistoryEnabled); err != nil { log.Errorf("netquery: failed to save connection %s in sqlite database: %s", conn.ID, err) continue @@ -158,7 +174,9 @@ func convertConnection(conn *network.Connection) (*Conn, error) { IPProtocol: conn.IPProtocol, LocalIP: conn.LocalIP.String(), LocalPort: conn.LocalPort, - Verdict: conn.Verdict.Firewall, // TODO: Expose both Worst and Firewall verdicts. + FirewallVerdict: conn.Verdict.Firewall, + ActiveVerdict: conn.Verdict.Active, + WorstVerdict: conn.Verdict.Worst, Started: time.Unix(conn.Started, 0), Tunneled: conn.Tunneled, Encrypted: conn.Encrypted, @@ -250,7 +268,7 @@ func convertConnection(conn *network.Connection) (*Conn, error) { } func genConnID(conn *network.Connection) string { - data := conn.ID + "-" + time.Unix(conn.Started, 0).String() + data := conn.ID + "-" + conn.Process().GetID() hash := sha256.Sum256([]byte(data)) return hex.EncodeToString(hash[:]) } diff --git a/netquery/module_api.go b/netquery/module_api.go index 4cb024628..344f93913 100644 --- a/netquery/module_api.go +++ b/netquery/module_api.go @@ -2,39 +2,58 @@ package netquery import ( "context" + "encoding/json" "fmt" + "net/http" "time" + "github.com/hashicorp/go-multierror" + "github.com/safing/portbase/api" "github.com/safing/portbase/config" "github.com/safing/portbase/database" "github.com/safing/portbase/database/query" "github.com/safing/portbase/log" "github.com/safing/portbase/modules" + "github.com/safing/portbase/modules/subsystems" "github.com/safing/portbase/runtime" "github.com/safing/portmaster/network" ) +// DefaultModule is the default netquery module. +var DefaultModule *module + type module struct { *modules.Module - db *database.Interface - sqlStore *Database - mng *Manager - feed chan *network.Connection + Store *Database + + db *database.Interface + mng *Manager + feed chan *network.Connection } func init() { - m := new(module) - m.Module = modules.Register( + DefaultModule = new(module) + + DefaultModule.Module = modules.Register( "netquery", - m.prepare, - m.start, - m.stop, + DefaultModule.prepare, + DefaultModule.start, + DefaultModule.stop, "api", "network", "database", ) + + subsystems.Register( + "history", + "Network History", + "Keep Network History Data", + DefaultModule.Module, + "config:history/", + nil, + ) } func (m *module) prepare() error { @@ -45,12 +64,12 @@ func (m *module) prepare() error { Internal: true, }) - m.sqlStore, err = NewInMemory() + m.Store, err = NewInMemory() if err != nil { return fmt.Errorf("failed to create in-memory database: %w", err) } - m.mng, err = NewManager(m.sqlStore, "netquery/data/", runtime.DefaultRegistry) + m.mng, err = NewManager(m.Store, "netquery/data/", runtime.DefaultRegistry) if err != nil { return fmt.Errorf("failed to create manager: %w", err) } @@ -58,12 +77,12 @@ func (m *module) prepare() error { m.feed = make(chan *network.Connection, 1000) queryHander := &QueryHandler{ - Database: m.sqlStore, + Database: m.Store, IsDevMode: config.Concurrent.GetAsBool(config.CfgDevModeKey, false), } chartHandler := &ChartHandler{ - Database: m.sqlStore, + Database: m.Store, } if err := api.RegisterEndpoint(api.Endpoint{ @@ -92,6 +111,56 @@ func (m *module) prepare() error { return fmt.Errorf("failed to register API endpoint: %w", err) } + if err := api.RegisterEndpoint(api.Endpoint{ + Path: "netquery/history/clear", + MimeType: "application/json", + Read: api.PermitUser, + Write: api.PermitUser, + BelongsTo: m.Module, + HandlerFunc: func(w http.ResponseWriter, r *http.Request) { + var body struct { + ProfileIDs []string `json:"profileIDs"` + } + + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + + if err := dec.Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if len(body.ProfileIDs) == 0 { + if err := m.mng.store.RemoveAllHistoryData(r.Context()); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + } else { + merr := new(multierror.Error) + for _, profileID := range body.ProfileIDs { + if err := m.mng.store.RemoveHistoryForProfile(r.Context(), profileID); err != nil { + merr.Errors = append(merr.Errors, fmt.Errorf("failed to clear history for %q: %w", profileID, err)) + } else { + log.Infof("netquery: successfully cleared history for %s", profileID) + } + } + + if err := merr.ErrorOrNil(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + } + + w.WriteHeader(http.StatusNoContent) + }, + Name: "Remove connections from profile history", + Description: "Remove all connections from the history database for one or more profiles", + }); err != nil { + return fmt.Errorf("failed to register API endpoint: %w", err) + } + return nil } @@ -139,7 +208,7 @@ func (m *module) start() error { return nil case <-time.After(10 * time.Second): threshold := time.Now().Add(-network.DeleteConnsAfterEndedThreshold) - count, err := m.sqlStore.Cleanup(ctx, threshold) + count, err := m.Store.Cleanup(ctx, threshold) if err != nil { log.Errorf("netquery: failed to count number of rows in memory: %s", err) } else { @@ -153,7 +222,7 @@ func (m *module) start() error { // the runtime database. // Only expose in development mode. if config.GetAsBool(config.CfgDevModeKey, false)() { - _, err := NewRuntimeQueryRunner(m.sqlStore, "netquery/query/", runtime.DefaultRegistry) + _, err := NewRuntimeQueryRunner(m.Store, "netquery/query/", runtime.DefaultRegistry) if err != nil { return fmt.Errorf("failed to set up runtime SQL query runner: %w", err) } @@ -163,5 +232,16 @@ func (m *module) start() error { } func (m *module) stop() error { + // we don't use m.Module.Ctx here because it is already cancelled when stop is called. + // just give the clean up 1 minute to happen and abort otherwise. + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + if err := m.mng.store.MarkAllHistoryConnectionsEnded(ctx); err != nil { + // handle the error by just logging it. There's not much we can do here + // and returning an error to the module system doesn't help much as well... + log.Errorf("netquery: failed to mark connections in history database as ended: %s", err) + } + return nil } diff --git a/netquery/orm/encoder.go b/netquery/orm/encoder.go index 7961f088c..ef86b842d 100644 --- a/netquery/orm/encoder.go +++ b/netquery/orm/encoder.go @@ -6,6 +6,7 @@ import ( "reflect" "time" + "golang.org/x/exp/slices" "zombiezen.com/go/sqlite" ) @@ -22,7 +23,7 @@ type ( // ToParamMap returns a map that contains the sqlite compatible value of each struct field of // r using the sqlite column name as a map key. It either uses the name of the // exported struct field or the value of the "sqlite" tag. -func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig) (map[string]interface{}, error) { +func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg EncodeConfig, skipFields []string) (map[string]interface{}, error) { // make sure we work on a struct type val := reflect.Indirect(reflect.ValueOf(r)) if val.Kind() != reflect.Struct { @@ -45,6 +46,10 @@ func ToParamMap(ctx context.Context, r interface{}, keyPrefix string, cfg Encode return nil, fmt.Errorf("failed to get column definition for %s: %w", fieldType.Name, err) } + if slices.Contains(skipFields, colDef.Name) { + continue + } + x, found, err := runEncodeHooks( colDef, fieldType.Type, diff --git a/netquery/orm/encoder_test.go b/netquery/orm/encoder_test.go index e5142962d..d0d3c0392 100644 --- a/netquery/orm/encoder_test.go +++ b/netquery/orm/encoder_test.go @@ -119,7 +119,7 @@ func TestEncodeAsMap(t *testing.T) { //nolint:tparallel for idx := range cases { //nolint:paralleltest c := cases[idx] t.Run(c.Desc, func(t *testing.T) { - res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig) + res, err := ToParamMap(ctx, c.Input, "", DefaultEncodeConfig, nil) assert.NoError(t, err) assert.Equal(t, c.Expected, res) }) diff --git a/netquery/orm/query_runner.go b/netquery/orm/query_runner.go index 55bafe30c..135a29f61 100644 --- a/netquery/orm/query_runner.go +++ b/netquery/orm/query_runner.go @@ -143,7 +143,23 @@ func RunQuery(ctx context.Context, conn *sqlite.Conn, sql string, modifiers ...Q currentField := reflect.New(valElemType) if err := DecodeStmt(ctx, &args.Schema, stmt, currentField.Interface(), args.DecodeConfig); err != nil { - return err + resultDump := make(map[string]any) + + for colIdx := 0; colIdx < stmt.ColumnCount(); colIdx++ { + name := stmt.ColumnName(colIdx) + + switch stmt.ColumnType(colIdx) { //nolint:exhaustive // TODO: handle type BLOB? + case sqlite.TypeText: + resultDump[name] = stmt.ColumnText(colIdx) + case sqlite.TypeFloat: + resultDump[name] = stmt.ColumnFloat(colIdx) + case sqlite.TypeInteger: + resultDump[name] = stmt.ColumnInt(colIdx) + case sqlite.TypeNull: + resultDump[name] = "" + } + } + return fmt.Errorf("%w: %+v", err, resultDump) } sliceVal = reflect.Append(sliceVal, reflect.Indirect(currentField)) diff --git a/netquery/orm/schema_builder.go b/netquery/orm/schema_builder.go index 508b7b186..6aba2a1f7 100644 --- a/netquery/orm/schema_builder.go +++ b/netquery/orm/schema_builder.go @@ -8,6 +8,8 @@ import ( "strings" "zombiezen.com/go/sqlite" + + "github.com/safing/portbase/log" ) var errSkipStructField = errors.New("struct field should be skipped") @@ -25,6 +27,7 @@ var ( TagTypePrefixVarchar = "varchar" TagTypeBlob = "blob" TagTypeFloat = "float" + TagTypePrefixDefault = "default=" ) var sqlTypeMap = map[sqlite.ColumnType]string{ @@ -52,6 +55,7 @@ type ( AutoIncrement bool UnixNano bool IsTime bool + Default any } ) @@ -66,12 +70,17 @@ func (ts TableSchema) GetColumnDef(name string) *ColumnDef { } // CreateStatement build the CREATE SQL statement for the table. -func (ts TableSchema) CreateStatement(ifNotExists bool) string { +func (ts TableSchema) CreateStatement(databaseName string, ifNotExists bool) string { sql := "CREATE TABLE" if ifNotExists { sql += " IF NOT EXISTS" } - sql += " " + ts.Name + " ( " + name := ts.Name + if databaseName != "" { + name = databaseName + "." + ts.Name + } + + sql += " " + name + " ( " for idx, col := range ts.Columns { sql += col.AsSQL() @@ -100,6 +109,21 @@ func (def ColumnDef) AsSQL() string { if def.AutoIncrement { sql += " AUTOINCREMENT" } + if def.Default != nil { + sql += " DEFAULT " + switch def.Type { //nolint:exhaustive // TODO: handle types BLOB, NULL? + case sqlite.TypeFloat: + sql += strconv.FormatFloat(def.Default.(float64), 'b', 0, 64) //nolint:forcetypeassert + case sqlite.TypeInteger: + sql += strconv.FormatInt(def.Default.(int64), 10) //nolint:forcetypeassert + case sqlite.TypeText: + sql += fmt.Sprintf("%q", def.Default.(string)) //nolint:forcetypeassert + default: + log.Errorf("unsupported default value: %q %q", def.Type, def.Default) + sql = strings.TrimSuffix(sql, " DEFAULT ") + } + sql += " " + } if !def.Nullable { sql += " NOT NULL" } @@ -155,7 +179,7 @@ func getColumnDef(fieldType reflect.StructField) (*ColumnDef, error) { kind := normalizeKind(ft.Kind()) switch kind { //nolint:exhaustive - case reflect.Int: + case reflect.Int, reflect.Uint: def.Type = sqlite.TypeInteger case reflect.Float64: @@ -232,6 +256,30 @@ func applyStructFieldTag(fieldType reflect.StructField, def *ColumnDef) error { def.Length = int(length) } + if strings.HasPrefix(k, TagTypePrefixDefault) { + defaultValue := strings.TrimPrefix(k, TagTypePrefixDefault) + switch def.Type { //nolint:exhaustive + case sqlite.TypeFloat: + fv, err := strconv.ParseFloat(defaultValue, 64) + if err != nil { + return fmt.Errorf("failed to parse default value as float %q: %w", defaultValue, err) + } + def.Default = fv + case sqlite.TypeInteger: + fv, err := strconv.ParseInt(defaultValue, 10, 0) + if err != nil { + return fmt.Errorf("failed to parse default value as int %q: %w", defaultValue, err) + } + def.Default = fv + case sqlite.TypeText: + def.Default = defaultValue + case sqlite.TypeBlob: + return fmt.Errorf("default values for TypeBlob not yet supported") + default: + return fmt.Errorf("failed to apply default value for unknown sqlite column type %s", def.Type) + } + } + } } } diff --git a/netquery/orm/schema_builder_test.go b/netquery/orm/schema_builder_test.go index 734da9814..fdd43ec79 100644 --- a/netquery/orm/schema_builder_test.go +++ b/netquery/orm/schema_builder_test.go @@ -22,14 +22,14 @@ func TestSchemaBuilder(t *testing.T) { Int *int `sqlite:",not-null"` Float interface{} `sqlite:",float,nullable"` }{}, - `CREATE TABLE Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`, + `CREATE TABLE main.Simple ( id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, text TEXT, Int INTEGER NOT NULL, Float REAL );`, }, { "Varchar", struct { S string `sqlite:",varchar(10)"` }{}, - `CREATE TABLE Varchar ( S VARCHAR(10) NOT NULL );`, + `CREATE TABLE main.Varchar ( S VARCHAR(10) NOT NULL );`, }, } @@ -38,6 +38,6 @@ func TestSchemaBuilder(t *testing.T) { res, err := GenerateTableSchema(c.Name, c.Model) assert.NoError(t, err) - assert.Equal(t, c.ExpectedSQL, res.CreateStatement(false)) + assert.Equal(t, c.ExpectedSQL, res.CreateStatement("main", false)) } } diff --git a/netquery/query.go b/netquery/query.go index 83dbc2172..06b766f68 100644 --- a/netquery/query.go +++ b/netquery/query.go @@ -14,6 +14,15 @@ import ( "github.com/safing/portmaster/netquery/orm" ) +// DatabaseName is a database name constant. +type DatabaseName string + +// Databases. +const ( + LiveDatabase = DatabaseName("main") + HistoryDatabase = DatabaseName("history") +) + // Collection of Query and Matcher types. // NOTE: whenever adding support for new operators make sure // to update UnmarshalJSON as well. @@ -48,11 +57,19 @@ type ( Distinct bool `json:"distinct"` } + Min struct { + Condition *Query `json:"condition,omitempty"` + Field string `json:"field"` + As string `json:"as"` + Distinct bool `json:"distinct"` + } + Select struct { Field string `json:"field"` Count *Count `json:"$count,omitempty"` Sum *Sum `json:"$sum,omitempty"` - Distinct *string `json:"$distinct"` + Min *Min `json:"$min,omitempty"` + Distinct *string `json:"$distinct,omitempty"` } Selects []Select @@ -68,6 +85,9 @@ type ( OrderBy OrderBys `json:"orderBy"` GroupBy []string `json:"groupBy"` TextSearch *TextSearch `json:"textSearch"` + // A list of databases to query. If left empty, + // both, the LiveDatabase and the HistoryDatabase are queried + Databases []DatabaseName `json:"databases"` Pagination @@ -457,6 +477,7 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { Field string `json:"field"` Count *Count `json:"$count"` Sum *Sum `json:"$sum"` + Min *Min `json:"$min"` Distinct *string `json:"$distinct"` } @@ -468,12 +489,23 @@ func (sel *Select) UnmarshalJSON(blob []byte) error { sel.Field = res.Field sel.Distinct = res.Distinct sel.Sum = res.Sum + sel.Min = res.Min if sel.Count != nil && sel.Count.As != "" { if !charOnlyRegexp.MatchString(sel.Count.As) { return fmt.Errorf("invalid characters in $count.as, value must match [a-zA-Z]+") } } + if sel.Sum != nil && sel.Sum.As != "" { + if !charOnlyRegexp.MatchString(sel.Sum.As) { + return fmt.Errorf("invalid characters in $sum.as, value must match [a-zA-Z]+") + } + } + if sel.Min != nil && sel.Min.As != "" { + if !charOnlyRegexp.MatchString(sel.Min.As) { + return fmt.Errorf("invalid characters in $min.as, value must match [a-zA-Z]+") + } + } return nil } diff --git a/netquery/query_handler.go b/netquery/query_handler.go index 599c71ec9..e555965d1 100644 --- a/netquery/query_handler.go +++ b/netquery/query_handler.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "golang.org/x/exp/slices" + "github.com/safing/portbase/log" "github.com/safing/portmaster/netquery/orm" ) @@ -152,13 +154,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab return "", nil, fmt.Errorf("generating where clause: %w", err) } - if req.paramMap == nil { - req.paramMap = make(map[string]interface{}) - } - - for key, val := range paramMap { - req.paramMap[key] = val - } + req.mergeParams(paramMap) if req.TextSearch != nil { textClause, textParams, err := req.TextSearch.toSQLConditionClause(ctx, schema, "", orm.DefaultEncodeConfig) @@ -173,9 +169,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab whereClause += textClause - for key, val := range textParams { - req.paramMap[key] = val - } + req.mergeParams(textParams) } } @@ -190,11 +184,24 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab } selectClause := req.generateSelectClause() - query := `SELECT ` + selectClause + ` FROM connections` + if whereClause != "" { - query += " WHERE " + whereClause + whereClause = "WHERE " + whereClause + } + + if len(req.Databases) == 0 { + req.Databases = []DatabaseName{LiveDatabase, HistoryDatabase} } + sources := make([]string, len(req.Databases)) + for idx, db := range req.Databases { + sources[idx] = fmt.Sprintf("SELECT * FROM %s.connections %s", db, whereClause) + } + + source := strings.Join(sources, " UNION ") + + query := `SELECT ` + selectClause + ` FROM ( ` + source + ` ) ` + query += " " + groupByClause + " " + orderByClause + " " + req.Pagination.toSQLLimitOffsetClause() return strings.TrimSpace(query), req.paramMap, nil @@ -203,6 +210,7 @@ func (req *QueryRequestPayload) generateSQL(ctx context.Context, schema *orm.Tab func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schema *orm.TableSchema) error { for idx, s := range req.Select { var field string + switch { case s.Count != nil: field = s.Count.Field @@ -211,6 +219,12 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem case s.Sum != nil: // field is not used in case of $sum field = "*" + case s.Min != nil: + if s.Min.Field != "" { + field = s.Min.Field + } else { + field = "*" + } default: field = s.Field } @@ -251,13 +265,40 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem return fmt.Errorf("in $sum: %w", err) } - req.paramMap = params + req.mergeParams(params) req.selectedFields = append( req.selectedFields, fmt.Sprintf("SUM(%s) AS %s", clause, s.Sum.As), ) req.whitelistedFields = append(req.whitelistedFields, s.Sum.As) + case s.Min != nil: + if s.Min.As == "" { + return fmt.Errorf("missing 'as' for $min") + } + + var ( + clause string + params map[string]any + ) + + if s.Min.Field != "" { + clause = field + } else { + var err error + clause, params, err = s.Min.Condition.toSQLWhereClause(ctx, fmt.Sprintf("sel%d", idx), schema, orm.DefaultEncodeConfig) + if err != nil { + return fmt.Errorf("in $min: %w", err) + } + } + + req.mergeParams(params) + req.selectedFields = append( + req.selectedFields, + fmt.Sprintf("MIN(%s) AS %s", clause, s.Min.As), + ) + req.whitelistedFields = append(req.whitelistedFields, s.Min.As) + case s.Distinct != nil: req.selectedFields = append(req.selectedFields, fmt.Sprintf("DISTINCT %s", colName)) req.whitelistedFields = append(req.whitelistedFields, colName) @@ -270,6 +311,16 @@ func (req *QueryRequestPayload) prepareSelectedFields(ctx context.Context, schem return nil } +func (req *QueryRequestPayload) mergeParams(params map[string]any) { + if req.paramMap == nil { + req.paramMap = make(map[string]any) + } + + for key, value := range params { + req.paramMap[key] = value + } +} + func (req *QueryRequestPayload) generateGroupByClause(schema *orm.TableSchema) (string, error) { if len(req.GroupBy) == 0 { return "", nil @@ -332,16 +383,12 @@ func (req *QueryRequestPayload) validateColumnName(schema *orm.TableSchema, fiel return colDef.Name, nil } - for _, selected := range req.whitelistedFields { - if field == selected { - return field, nil - } + if slices.Contains(req.whitelistedFields, field) { + return field, nil } - for _, selected := range req.selectedFields { - if field == selected { - return field, nil - } + if slices.Contains(req.selectedFields, field) { + return field, nil } return "", fmt.Errorf("column name %q not allowed", field) diff --git a/network/clean.go b/network/clean.go index a538b7f53..f31031424 100644 --- a/network/clean.go +++ b/network/clean.go @@ -78,7 +78,8 @@ func cleanConnections() (activePIDs map[int]struct{}) { } case conn.Ended < deleteOlderThan: // Step 3: delete - log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0)) + // DEBUG: + // log.Tracef("network.clean: deleted %s (ended at %s)", conn.DatabaseKey(), time.Unix(conn.Ended, 0)) conn.delete() } diff --git a/network/connection.go b/network/connection.go index 63182ca2a..972a5c5b9 100644 --- a/network/connection.go +++ b/network/connection.go @@ -19,6 +19,8 @@ import ( "github.com/safing/portmaster/process" _ "github.com/safing/portmaster/process/tags" "github.com/safing/portmaster/resolver" + "github.com/safing/spn/access" + "github.com/safing/spn/access/account" "github.com/safing/spn/navigator" ) @@ -173,8 +175,17 @@ type Connection struct { //nolint:maligned // TODO: fix alignment StopTunnel() error } - RecvBytes uint64 - SentBytes uint64 + // HistoryEnabled is set to true when the connection should be persisted + // in the history database. + HistoryEnabled bool + // BanwidthEnabled is set to true if connection bandwidth data should be persisted + // in netquery. + BandwidthEnabled bool + + // BytesReceived holds the observed received bytes of the connection. + BytesReceived uint64 + // BytesSent holds the observed sent bytes of the connection. + BytesSent uint64 // pkgQueue is used to serialize packet handling for a single // connection and is served by the connections packetHandler. @@ -326,6 +337,10 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri // Inherit internal status of profile. if localProfile := proc.Profile().LocalProfile(); localProfile != nil { dnsConn.Internal = localProfile.Internal + + if err := dnsConn.updateFeatures(); err != nil { + log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err) + } } // DNS Requests are saved by the nameserver depending on the result of the @@ -364,6 +379,10 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname // Inherit internal status of profile. if localProfile := remoteHost.Profile().LocalProfile(); localProfile != nil { dnsConn.Internal = localProfile.Internal + + if err := dnsConn.updateFeatures(); err != nil { + log.Tracer(ctx).Warningf("network: failed to check for enabled features: %s", err) + } } // DNS Requests are saved by the nameserver depending on the result of the @@ -374,6 +393,8 @@ func NewConnectionFromExternalDNSRequest(ctx context.Context, fqdn string, cname return dnsConn, nil } +var tooOldTimestamp = time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC).Unix() + // NewIncompleteConnection creates a new incomplete connection with only minimal information. func NewIncompleteConnection(pkt packet.Packet) *Connection { info := pkt.Info() @@ -390,6 +411,12 @@ func NewIncompleteConnection(pkt packet.Packet) *Connection { dataComplete: abool.NewBool(false), } + // Bullshit check Started timestamp. + if conn.Started < tooOldTimestamp { + // Fix timestamp, use current time as fallback. + conn.Started = time.Now().Unix() + } + // Save connection to internal state in order to mitigate creation of // duplicates. Do not propagate yet, as data is not yet complete. conn.UpdateMeta() @@ -420,7 +447,12 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) { // Inherit internal status of profile. if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil { conn.Internal = localProfile.Internal + + if err := conn.updateFeatures(); err != nil { + log.Tracer(pkt.Ctx()).Warningf("network: failed to check for enabled features: %s", err) + } } + } else { conn.process = nil if pkt.InfoOnly() { @@ -533,6 +565,31 @@ func (conn *Connection) SetLocalIP(ip net.IP) { conn.LocalIPScope = netutils.GetIPScope(ip) } +// updateFeatures checks which connection related features may be used and sets +// the flags accordingly. +func (conn *Connection) updateFeatures() error { + // Get user. + user, err := access.GetUser() + if err != nil { + return err + } + + // Check if history may be used and if it is enabled for this application. + if user.MayUse(account.FeatureHistory) { + lProfile := conn.Process().Profile() + if lProfile != nil { + conn.HistoryEnabled = lProfile.HistoryEnabled() + } + } + + // Check if bandwidth visibility may be used. + if user.MayUse(account.FeatureBWVis) { + conn.BandwidthEnabled = true + } + + return nil +} + // AcceptWithContext accepts the connection. func (conn *Connection) AcceptWithContext(reason, reasonOptionKey string, ctx interface{}) { if !conn.SetVerdict(VerdictAccept, reason, reasonOptionKey, ctx) { diff --git a/network/packet/bandwidth.go b/network/packet/bandwidth.go index c65ac0854..c2ce6a01f 100644 --- a/network/packet/bandwidth.go +++ b/network/packet/bandwidth.go @@ -4,10 +4,10 @@ import "fmt" // BandwidthUpdate holds an update to the seen bandwidth of a connection. type BandwidthUpdate struct { - ConnID string - RecvBytes uint64 - SentBytes uint64 - Method BandwidthUpdateMethod + ConnID string + BytesReceived uint64 + BytesSent uint64 + Method BandwidthUpdateMethod } // BandwidthUpdateMethod defines how the bandwidth data of a bandwidth update should be interpreted. @@ -20,7 +20,7 @@ const ( ) func (bu *BandwidthUpdate) String() string { - return fmt.Sprintf("%s: %dB recv | %dB sent [%s]", bu.ConnID, bu.RecvBytes, bu.SentBytes, bu.Method) + return fmt.Sprintf("%s: %dB recv | %dB sent [%s]", bu.ConnID, bu.BytesReceived, bu.BytesSent, bu.Method) } func (bum BandwidthUpdateMethod) String() string { diff --git a/process/process.go b/process/process.go index 3f2779f98..9f0acc7e4 100644 --- a/process/process.go +++ b/process/process.go @@ -313,6 +313,13 @@ func loadProcess(ctx context.Context, key string, pInfo *processInfo.Process) (* return process, nil } +// GetID returns the key that is used internally to identify the process. +// The ID consists of the PID and the start time of the process as reported by +// the system. +func (p *Process) GetID() string { + return p.processKey +} + // Builds a unique identifier for a processes. func getProcessKey(pid int32, createdTime int64) string { return fmt.Sprintf("%d-%d", pid, createdTime) diff --git a/profile/config.go b/profile/config.go index 416de06b7..ff3e072c8 100644 --- a/profile/config.go +++ b/profile/config.go @@ -6,6 +6,7 @@ import ( "github.com/safing/portbase/config" "github.com/safing/portmaster/profile/endpoints" "github.com/safing/portmaster/status" + "github.com/safing/spn/access/account" "github.com/safing/spn/navigator" ) @@ -103,7 +104,13 @@ var ( cfgOptionDisableAutoPermit config.IntOption // security level option cfgOptionDisableAutoPermitOrder = 65 - // Setting "Permanent Verdicts" at order 96. + // Setting "Permanent Verdicts" at order 80. + + // Network History. + + CfgOptionEnableHistoryKey = "history/enable" + cfgOptionEnableHistory config.BoolOption + cfgOptionEnableHistoryOrder = 96 // Setting "Enable SPN" at order 128. @@ -239,6 +246,27 @@ func registerConfiguration() error { //nolint:maintidx cfgOptionDisableAutoPermit = config.Concurrent.GetAsInt(CfgOptionDisableAutoPermitKey, int64(status.SecurityLevelsAll)) cfgIntOptions[CfgOptionDisableAutoPermitKey] = cfgOptionDisableAutoPermit + // Enable History + err = config.Register(&config.Option{ + Name: "Enable Connection History", + Key: CfgOptionEnableHistoryKey, + Description: "Whether or not to save connections to the history database", + OptType: config.OptTypeBool, + ReleaseLevel: config.ReleaseLevelStable, + ExpertiseLevel: config.ExpertiseLevelExpert, + DefaultValue: false, + Annotations: config.Annotations{ + config.DisplayOrderAnnotation: cfgOptionEnableHistoryOrder, + config.CategoryAnnotation: "History", + config.RequiresFeatureID: account.FeatureHistory, + }, + }) + if err != nil { + return err + } + cfgOptionEnableHistory = config.Concurrent.GetAsBool(CfgOptionEnableHistoryKey, false) + cfgBoolOptions[CfgOptionEnableHistoryKey] = cfgOptionEnableHistory + rulesHelp := strings.ReplaceAll(`Rules are checked from top to bottom, stopping after the first match. They can match: - By address: "192.168.0.1" diff --git a/profile/profile-layered.go b/profile/profile-layered.go index b2f7850b8..5380aca89 100644 --- a/profile/profile-layered.go +++ b/profile/profile-layered.go @@ -49,6 +49,7 @@ type LayeredProfile struct { DomainHeuristics config.BoolOption `json:"-"` UseSPN config.BoolOption `json:"-"` SPNRoutingAlgorithm config.StringOption `json:"-"` + HistoryEnabled config.BoolOption `json:"-"` } // NewLayeredProfile returns a new layered profile based on the given local profile. @@ -120,6 +121,10 @@ func NewLayeredProfile(localProfile *Profile) *LayeredProfile { CfgOptionRoutingAlgorithmKey, cfgOptionRoutingAlgorithm, ) + lp.HistoryEnabled = lp.wrapBoolOption( + CfgOptionEnableHistoryKey, + cfgOptionEnableHistory, + ) lp.LayerIDs = append(lp.LayerIDs, localProfile.ScopedID()) lp.layers = append(lp.layers, localProfile) diff --git a/profile/profile.go b/profile/profile.go index 1fa12ff81..2d0eb9c45 100644 --- a/profile/profile.go +++ b/profile/profile.go @@ -136,6 +136,7 @@ type Profile struct { //nolint:maligned // not worth the effort filterListIDs []string spnUsagePolicy endpoints.Endpoints spnExitHubPolicy endpoints.Endpoints + enableHistory bool // Lifecycle Management outdated *abool.AtomicBool @@ -233,6 +234,11 @@ func (profile *Profile) parseConfig() error { } } + enableHistory, ok := profile.configPerspective.GetAsBool(CfgOptionEnableHistoryKey) + if ok { + profile.enableHistory = enableHistory + } + return lastErr } @@ -315,6 +321,11 @@ func (profile *Profile) IsOutdated() bool { return profile.outdated.IsSet() } +// HistoryEnabled returns true if connection history is enabled for the profile. +func (profile *Profile) HistoryEnabled() bool { + return profile.enableHistory +} + // GetEndpoints returns the endpoint list of the profile. This functions // requires the profile to be read locked. func (profile *Profile) GetEndpoints() endpoints.Endpoints {