initial commit

This commit is contained in:
tumillanino
2025-11-12 18:34:08 +11:00
commit 2fed8f268e
585 changed files with 161655 additions and 0 deletions

View File

@@ -0,0 +1,341 @@
package bluez
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/AvengeMedia/danklinux/internal/errdefs"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/godbus/dbus/v5"
)
const (
bluezService = "org.bluez"
agentManagerPath = "/org/bluez"
agentManagerIface = "org.bluez.AgentManager1"
agent1Iface = "org.bluez.Agent1"
device1Iface = "org.bluez.Device1"
agentPath = "/com/danklinux/bluez/agent"
agentCapability = "KeyboardDisplay"
)
const introspectXML = `
<node>
<interface name="org.bluez.Agent1">
<method name="Release"/>
<method name="RequestPinCode">
<arg direction="in" type="o" name="device"/>
<arg direction="out" type="s" name="pincode"/>
</method>
<method name="RequestPasskey">
<arg direction="in" type="o" name="device"/>
<arg direction="out" type="u" name="passkey"/>
</method>
<method name="DisplayPinCode">
<arg direction="in" type="o" name="device"/>
<arg direction="in" type="s" name="pincode"/>
</method>
<method name="DisplayPasskey">
<arg direction="in" type="o" name="device"/>
<arg direction="in" type="u" name="passkey"/>
<arg direction="in" type="q" name="entered"/>
</method>
<method name="RequestConfirmation">
<arg direction="in" type="o" name="device"/>
<arg direction="in" type="u" name="passkey"/>
</method>
<method name="RequestAuthorization">
<arg direction="in" type="o" name="device"/>
</method>
<method name="AuthorizeService">
<arg direction="in" type="o" name="device"/>
<arg direction="in" type="s" name="uuid"/>
</method>
<method name="Cancel"/>
</interface>
<interface name="org.freedesktop.DBus.Introspectable">
<method name="Introspect">
<arg direction="out" type="s" name="data"/>
</method>
</interface>
</node>`
type BluezAgent struct {
conn *dbus.Conn
broker PromptBroker
}
func NewBluezAgent(broker PromptBroker) (*BluezAgent, error) {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("system bus connection failed: %w", err)
}
agent := &BluezAgent{
conn: conn,
broker: broker,
}
if err := conn.Export(agent, dbus.ObjectPath(agentPath), agent1Iface); err != nil {
conn.Close()
return nil, fmt.Errorf("agent export failed: %w", err)
}
if err := conn.Export(agent, dbus.ObjectPath(agentPath), "org.freedesktop.DBus.Introspectable"); err != nil {
conn.Close()
return nil, fmt.Errorf("introspection export failed: %w", err)
}
mgr := conn.Object(bluezService, dbus.ObjectPath(agentManagerPath))
if err := mgr.Call(agentManagerIface+".RegisterAgent", 0, dbus.ObjectPath(agentPath), agentCapability).Err; err != nil {
conn.Close()
return nil, fmt.Errorf("agent registration failed: %w", err)
}
if err := mgr.Call(agentManagerIface+".RequestDefaultAgent", 0, dbus.ObjectPath(agentPath)).Err; err != nil {
log.Debugf("[BluezAgent] not default agent: %v", err)
}
log.Infof("[BluezAgent] registered at %s with capability %s", agentPath, agentCapability)
return agent, nil
}
func (a *BluezAgent) Close() {
if a.conn == nil {
return
}
mgr := a.conn.Object(bluezService, dbus.ObjectPath(agentManagerPath))
mgr.Call(agentManagerIface+".UnregisterAgent", 0, dbus.ObjectPath(agentPath))
a.conn.Close()
}
func (a *BluezAgent) Release() *dbus.Error {
log.Infof("[BluezAgent] Release called")
return nil
}
func (a *BluezAgent) RequestPinCode(device dbus.ObjectPath) (string, *dbus.Error) {
log.Infof("[BluezAgent] RequestPinCode: device=%s", device)
secrets, err := a.promptFor(device, "pin", []string{"pin"}, nil)
if err != nil {
log.Warnf("[BluezAgent] RequestPinCode failed: %v", err)
return "", a.errorFrom(err)
}
pin := secrets["pin"]
log.Infof("[BluezAgent] RequestPinCode returning PIN (len=%d)", len(pin))
return pin, nil
}
func (a *BluezAgent) RequestPasskey(device dbus.ObjectPath) (uint32, *dbus.Error) {
log.Infof("[BluezAgent] RequestPasskey: device=%s", device)
secrets, err := a.promptFor(device, "passkey", []string{"passkey"}, nil)
if err != nil {
log.Warnf("[BluezAgent] RequestPasskey failed: %v", err)
return 0, a.errorFrom(err)
}
passkey, err := strconv.ParseUint(secrets["passkey"], 10, 32)
if err != nil {
log.Warnf("[BluezAgent] invalid passkey format: %v", err)
return 0, dbus.MakeFailedError(fmt.Errorf("invalid passkey: %w", err))
}
log.Infof("[BluezAgent] RequestPasskey returning: %d", passkey)
return uint32(passkey), nil
}
func (a *BluezAgent) DisplayPinCode(device dbus.ObjectPath, pincode string) *dbus.Error {
log.Infof("[BluezAgent] DisplayPinCode: device=%s, pin=%s", device, pincode)
_, err := a.promptFor(device, "display-pin", []string{}, &pincode)
if err != nil {
log.Warnf("[BluezAgent] DisplayPinCode acknowledgment failed: %v", err)
}
return nil
}
func (a *BluezAgent) DisplayPasskey(device dbus.ObjectPath, passkey uint32, entered uint16) *dbus.Error {
log.Infof("[BluezAgent] DisplayPasskey: device=%s, passkey=%06d, entered=%d", device, passkey, entered)
if entered == 0 {
pk := passkey
_, err := a.promptFor(device, "display-passkey", []string{}, nil)
if err != nil {
log.Warnf("[BluezAgent] DisplayPasskey acknowledgment failed: %v", err)
}
_ = pk
}
return nil
}
func (a *BluezAgent) RequestConfirmation(device dbus.ObjectPath, passkey uint32) *dbus.Error {
log.Infof("[BluezAgent] RequestConfirmation: device=%s, passkey=%06d", device, passkey)
secrets, err := a.promptFor(device, "confirm", []string{"decision"}, nil)
if err != nil {
log.Warnf("[BluezAgent] RequestConfirmation failed: %v", err)
return a.errorFrom(err)
}
if secrets["decision"] != "yes" && secrets["decision"] != "accept" {
log.Debugf("[BluezAgent] RequestConfirmation rejected by user")
return dbus.NewError("org.bluez.Error.Rejected", nil)
}
log.Infof("[BluezAgent] RequestConfirmation accepted")
return nil
}
func (a *BluezAgent) RequestAuthorization(device dbus.ObjectPath) *dbus.Error {
log.Infof("[BluezAgent] RequestAuthorization: device=%s", device)
secrets, err := a.promptFor(device, "authorize", []string{"decision"}, nil)
if err != nil {
log.Warnf("[BluezAgent] RequestAuthorization failed: %v", err)
return a.errorFrom(err)
}
if secrets["decision"] != "yes" && secrets["decision"] != "accept" {
log.Debugf("[BluezAgent] RequestAuthorization rejected by user")
return dbus.NewError("org.bluez.Error.Rejected", nil)
}
log.Infof("[BluezAgent] RequestAuthorization accepted")
return nil
}
func (a *BluezAgent) AuthorizeService(device dbus.ObjectPath, uuid string) *dbus.Error {
log.Infof("[BluezAgent] AuthorizeService: device=%s, uuid=%s", device, uuid)
secrets, err := a.promptFor(device, "authorize-service:"+uuid, []string{"decision"}, nil)
if err != nil {
log.Warnf("[BluezAgent] AuthorizeService failed: %v", err)
return a.errorFrom(err)
}
if secrets["decision"] != "yes" && secrets["decision"] != "accept" {
log.Debugf("[BluezAgent] AuthorizeService rejected by user")
return dbus.NewError("org.bluez.Error.Rejected", nil)
}
log.Infof("[BluezAgent] AuthorizeService accepted")
return nil
}
func (a *BluezAgent) Cancel() *dbus.Error {
log.Infof("[BluezAgent] Cancel called")
return nil
}
func (a *BluezAgent) Introspect() (string, *dbus.Error) {
return introspectXML, nil
}
func (a *BluezAgent) promptFor(device dbus.ObjectPath, requestType string, fields []string, displayValue *string) (map[string]string, error) {
if a.broker == nil {
return nil, fmt.Errorf("broker not initialized")
}
deviceName, deviceAddr := a.getDeviceInfo(device)
hints := []string{}
if displayValue != nil {
hints = append(hints, *displayValue)
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
var passkey *uint32
if requestType == "confirm" || requestType == "display-passkey" {
if displayValue != nil {
if pk, err := strconv.ParseUint(*displayValue, 10, 32); err == nil {
pk32 := uint32(pk)
passkey = &pk32
}
}
}
token, err := a.broker.Ask(ctx, PromptRequest{
DevicePath: string(device),
DeviceName: deviceName,
DeviceAddr: deviceAddr,
RequestType: requestType,
Fields: fields,
Hints: hints,
Passkey: passkey,
})
if err != nil {
return nil, fmt.Errorf("prompt creation failed: %w", err)
}
log.Infof("[BluezAgent] waiting for user response (token=%s)", token)
reply, err := a.broker.Wait(ctx, token)
if err != nil {
if errors.Is(err, errdefs.ErrSecretPromptTimeout) {
return nil, err
}
if reply.Cancel || errors.Is(err, errdefs.ErrSecretPromptCancelled) {
return nil, errdefs.ErrSecretPromptCancelled
}
return nil, err
}
if !reply.Accept && len(fields) > 0 {
return nil, errdefs.ErrSecretPromptCancelled
}
return reply.Secrets, nil
}
func (a *BluezAgent) getDeviceInfo(device dbus.ObjectPath) (string, string) {
obj := a.conn.Object(bluezService, device)
var name, alias, addr string
nameVar, err := obj.GetProperty(device1Iface + ".Name")
if err == nil {
if n, ok := nameVar.Value().(string); ok {
name = n
}
}
aliasVar, err := obj.GetProperty(device1Iface + ".Alias")
if err == nil {
if a, ok := aliasVar.Value().(string); ok {
alias = a
}
}
addrVar, err := obj.GetProperty(device1Iface + ".Address")
if err == nil {
if a, ok := addrVar.Value().(string); ok {
addr = a
}
}
if alias != "" {
return alias, addr
}
if name != "" {
return name, addr
}
return addr, addr
}
func (a *BluezAgent) errorFrom(err error) *dbus.Error {
if errors.Is(err, errdefs.ErrSecretPromptTimeout) {
return dbus.NewError("org.bluez.Error.Canceled", nil)
}
if errors.Is(err, errdefs.ErrSecretPromptCancelled) {
return dbus.NewError("org.bluez.Error.Canceled", nil)
}
return dbus.MakeFailedError(err)
}

View File

@@ -0,0 +1,21 @@
package bluez
import (
"context"
"crypto/rand"
"encoding/hex"
)
type PromptBroker interface {
Ask(ctx context.Context, req PromptRequest) (token string, err error)
Wait(ctx context.Context, token string) (PromptReply, error)
Resolve(token string, reply PromptReply) error
}
func generateToken() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -0,0 +1,220 @@
package bluez
import (
"context"
"testing"
"time"
)
func TestSubscriptionBrokerAskWait(t *testing.T) {
promptReceived := false
broker := NewSubscriptionBroker(func(p PairingPrompt) {
promptReceived = true
if p.Token == "" {
t.Error("expected token to be non-empty")
}
if p.DeviceName != "TestDevice" {
t.Errorf("expected DeviceName=TestDevice, got %s", p.DeviceName)
}
})
ctx := context.Background()
req := PromptRequest{
DevicePath: "/org/bluez/test",
DeviceName: "TestDevice",
DeviceAddr: "AA:BB:CC:DD:EE:FF",
RequestType: "pin",
Fields: []string{"pin"},
}
token, err := broker.Ask(ctx, req)
if err != nil {
t.Fatalf("Ask failed: %v", err)
}
if token == "" {
t.Fatal("expected non-empty token")
}
if !promptReceived {
t.Fatal("expected prompt broadcast to be called")
}
go func() {
time.Sleep(50 * time.Millisecond)
broker.Resolve(token, PromptReply{
Secrets: map[string]string{"pin": "1234"},
Accept: true,
})
}()
reply, err := broker.Wait(ctx, token)
if err != nil {
t.Fatalf("Wait failed: %v", err)
}
if reply.Secrets["pin"] != "1234" {
t.Errorf("expected pin=1234, got %s", reply.Secrets["pin"])
}
if !reply.Accept {
t.Error("expected Accept=true")
}
}
func TestSubscriptionBrokerTimeout(t *testing.T) {
broker := NewSubscriptionBroker(nil)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
req := PromptRequest{
DevicePath: "/org/bluez/test",
DeviceName: "TestDevice",
RequestType: "passkey",
Fields: []string{"passkey"},
}
token, err := broker.Ask(ctx, req)
if err != nil {
t.Fatalf("Ask failed: %v", err)
}
_, err = broker.Wait(ctx, token)
if err == nil {
t.Fatal("expected timeout error")
}
}
func TestSubscriptionBrokerCancel(t *testing.T) {
broker := NewSubscriptionBroker(nil)
ctx := context.Background()
req := PromptRequest{
DevicePath: "/org/bluez/test",
DeviceName: "TestDevice",
RequestType: "confirm",
Fields: []string{"decision"},
}
token, err := broker.Ask(ctx, req)
if err != nil {
t.Fatalf("Ask failed: %v", err)
}
go func() {
time.Sleep(50 * time.Millisecond)
broker.Resolve(token, PromptReply{
Cancel: true,
})
}()
_, err = broker.Wait(ctx, token)
if err == nil {
t.Fatal("expected cancelled error")
}
}
func TestSubscriptionBrokerUnknownToken(t *testing.T) {
broker := NewSubscriptionBroker(nil)
ctx := context.Background()
_, err := broker.Wait(ctx, "invalid-token")
if err == nil {
t.Fatal("expected error for unknown token")
}
}
func TestGenerateToken(t *testing.T) {
token1, err := generateToken()
if err != nil {
t.Fatalf("generateToken failed: %v", err)
}
token2, err := generateToken()
if err != nil {
t.Fatalf("generateToken failed: %v", err)
}
if token1 == token2 {
t.Error("expected unique tokens")
}
if len(token1) != 32 {
t.Errorf("expected token length 32, got %d", len(token1))
}
}
func TestSubscriptionBrokerResolveUnknownToken(t *testing.T) {
broker := NewSubscriptionBroker(nil)
err := broker.Resolve("unknown-token", PromptReply{
Secrets: map[string]string{"test": "value"},
})
if err == nil {
t.Fatal("expected error for unknown token")
}
}
func TestSubscriptionBrokerMultipleRequests(t *testing.T) {
broker := NewSubscriptionBroker(nil)
ctx := context.Background()
req1 := PromptRequest{
DevicePath: "/org/bluez/test1",
DeviceName: "Device1",
RequestType: "pin",
Fields: []string{"pin"},
}
req2 := PromptRequest{
DevicePath: "/org/bluez/test2",
DeviceName: "Device2",
RequestType: "passkey",
Fields: []string{"passkey"},
}
token1, err := broker.Ask(ctx, req1)
if err != nil {
t.Fatalf("Ask1 failed: %v", err)
}
token2, err := broker.Ask(ctx, req2)
if err != nil {
t.Fatalf("Ask2 failed: %v", err)
}
if token1 == token2 {
t.Error("expected different tokens")
}
go func() {
time.Sleep(50 * time.Millisecond)
broker.Resolve(token1, PromptReply{
Secrets: map[string]string{"pin": "1234"},
Accept: true,
})
broker.Resolve(token2, PromptReply{
Secrets: map[string]string{"passkey": "567890"},
Accept: true,
})
}()
reply1, err := broker.Wait(ctx, token1)
if err != nil {
t.Fatalf("Wait1 failed: %v", err)
}
reply2, err := broker.Wait(ctx, token2)
if err != nil {
t.Fatalf("Wait2 failed: %v", err)
}
if reply1.Secrets["pin"] != "1234" {
t.Errorf("expected pin=1234, got %s", reply1.Secrets["pin"])
}
if reply2.Secrets["passkey"] != "567890" {
t.Errorf("expected passkey=567890, got %s", reply2.Secrets["passkey"])
}
}

View File

@@ -0,0 +1,260 @@
package bluez
import (
"encoding/json"
"fmt"
"net"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
type BluetoothEvent struct {
Type string `json:"type"`
Data BluetoothState `json:"data"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method {
case "bluetooth.getState":
handleGetState(conn, req, manager)
case "bluetooth.startDiscovery":
handleStartDiscovery(conn, req, manager)
case "bluetooth.stopDiscovery":
handleStopDiscovery(conn, req, manager)
case "bluetooth.setPowered":
handleSetPowered(conn, req, manager)
case "bluetooth.pair":
handlePairDevice(conn, req, manager)
case "bluetooth.connect":
handleConnectDevice(conn, req, manager)
case "bluetooth.disconnect":
handleDisconnectDevice(conn, req, manager)
case "bluetooth.remove":
handleRemoveDevice(conn, req, manager)
case "bluetooth.trust":
handleTrustDevice(conn, req, manager)
case "bluetooth.untrust":
handleUntrustDevice(conn, req, manager)
case "bluetooth.subscribe":
handleSubscribe(conn, req, manager)
case "bluetooth.pairing.submit":
handlePairingSubmit(conn, req, manager)
case "bluetooth.pairing.cancel":
handlePairingCancel(conn, req, manager)
default:
models.RespondError(conn, req.ID, fmt.Sprintf("unknown method: %s", req.Method))
}
}
func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState()
models.Respond(conn, req.ID, state)
}
func handleStartDiscovery(conn net.Conn, req Request, manager *Manager) {
if err := manager.StartDiscovery(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "discovery started"})
}
func handleStopDiscovery(conn net.Conn, req Request, manager *Manager) {
if err := manager.StopDiscovery(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "discovery stopped"})
}
func handleSetPowered(conn net.Conn, req Request, manager *Manager) {
powered, ok := req.Params["powered"].(bool)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'powered' parameter")
return
}
if err := manager.SetPowered(powered); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "powered state updated"})
}
func handlePairDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return
}
if err := manager.PairDevice(devicePath); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "pairing initiated"})
}
func handleConnectDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return
}
if err := manager.ConnectDevice(devicePath); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
}
func handleDisconnectDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return
}
if err := manager.DisconnectDevice(devicePath); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "disconnected"})
}
func handleRemoveDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return
}
if err := manager.RemoveDevice(devicePath); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "device removed"})
}
func handleTrustDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return
}
if err := manager.TrustDevice(devicePath, true); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "device trusted"})
}
func handleUntrustDevice(conn net.Conn, req Request, manager *Manager) {
devicePath, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'device' parameter")
return
}
if err := manager.TrustDevice(devicePath, false); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "device untrusted"})
}
func handlePairingSubmit(conn net.Conn, req Request, manager *Manager) {
token, ok := req.Params["token"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return
}
secretsRaw, ok := req.Params["secrets"].(map[string]interface{})
secrets := make(map[string]string)
if ok {
for k, v := range secretsRaw {
if str, ok := v.(string); ok {
secrets[k] = str
}
}
}
accept := false
if acceptParam, ok := req.Params["accept"].(bool); ok {
accept = acceptParam
}
if err := manager.SubmitPairing(token, secrets, accept); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "pairing response submitted"})
}
func handlePairingCancel(conn net.Conn, req Request, manager *Manager) {
token, ok := req.Params["token"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return
}
if err := manager.CancelPairing(token); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "pairing cancelled"})
}
func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID)
initialState := manager.GetState()
event := BluetoothEvent{
Type: "state_changed",
Data: initialState,
}
if err := json.NewEncoder(conn).Encode(models.Response[BluetoothEvent]{
ID: req.ID,
Result: &event,
}); err != nil {
return
}
for state := range stateChan {
event := BluetoothEvent{
Type: "state_changed",
Data: state,
}
if err := json.NewEncoder(conn).Encode(models.Response[BluetoothEvent]{
Result: &event,
}); err != nil {
return
}
}
}

View File

@@ -0,0 +1,41 @@
package bluez
import (
"context"
"testing"
"time"
)
func TestBrokerIntegration(t *testing.T) {
broker := NewSubscriptionBroker(nil)
ctx := context.Background()
req := PromptRequest{
DevicePath: "/org/bluez/test",
DeviceName: "TestDevice",
RequestType: "pin",
Fields: []string{"pin"},
}
token, err := broker.Ask(ctx, req)
if err != nil {
t.Fatalf("Ask failed: %v", err)
}
go func() {
time.Sleep(50 * time.Millisecond)
broker.Resolve(token, PromptReply{
Secrets: map[string]string{"pin": "1234"},
Accept: true,
})
}()
reply, err := broker.Wait(ctx, token)
if err != nil {
t.Fatalf("Wait failed: %v", err)
}
if reply.Secrets["pin"] != "1234" {
t.Errorf("expected pin=1234, got %s", reply.Secrets["pin"])
}
}

View File

@@ -0,0 +1,668 @@
package bluez
import (
"fmt"
"strings"
"sync"
"time"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/godbus/dbus/v5"
)
const (
adapter1Iface = "org.bluez.Adapter1"
objectMgrIface = "org.freedesktop.DBus.ObjectManager"
propertiesIface = "org.freedesktop.DBus.Properties"
)
func NewManager() (*Manager, error) {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("system bus connection failed: %w", err)
}
m := &Manager{
state: &BluetoothState{
Powered: false,
Discovering: false,
Devices: []Device{},
PairedDevices: []Device{},
ConnectedDevices: []Device{},
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan BluetoothState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dbusConn: conn,
signals: make(chan *dbus.Signal, 256),
pairingSubscribers: make(map[string]chan PairingPrompt),
pairingSubMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
pendingPairings: make(map[string]bool),
eventQueue: make(chan func(), 32),
}
broker := NewSubscriptionBroker(m.broadcastPairingPrompt)
m.promptBroker = broker
adapter, err := m.findAdapter()
if err != nil {
conn.Close()
return nil, fmt.Errorf("no bluetooth adapter found: %w", err)
}
m.adapterPath = adapter
if err := m.initialize(); err != nil {
conn.Close()
return nil, err
}
if err := m.startAgent(); err != nil {
conn.Close()
return nil, fmt.Errorf("agent start failed: %w", err)
}
if err := m.startSignalPump(); err != nil {
m.Close()
return nil, err
}
m.notifierWg.Add(1)
go m.notifier()
m.eventWg.Add(1)
go m.eventWorker()
return m, nil
}
func (m *Manager) findAdapter() (dbus.ObjectPath, error) {
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath("/"))
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
if err := obj.Call(objectMgrIface+".GetManagedObjects", 0).Store(&objects); err != nil {
return "", err
}
for path, interfaces := range objects {
if _, ok := interfaces[adapter1Iface]; ok {
log.Infof("[BluezManager] found adapter: %s", path)
return path, nil
}
}
return "", fmt.Errorf("no adapter found")
}
func (m *Manager) initialize() error {
if err := m.updateAdapterState(); err != nil {
return err
}
if err := m.updateDevices(); err != nil {
return err
}
return nil
}
func (m *Manager) updateAdapterState() error {
obj := m.dbusConn.Object(bluezService, m.adapterPath)
poweredVar, err := obj.GetProperty(adapter1Iface + ".Powered")
if err != nil {
return err
}
powered, _ := poweredVar.Value().(bool)
discoveringVar, err := obj.GetProperty(adapter1Iface + ".Discovering")
if err != nil {
return err
}
discovering, _ := discoveringVar.Value().(bool)
m.stateMutex.Lock()
m.state.Powered = powered
m.state.Discovering = discovering
m.stateMutex.Unlock()
return nil
}
func (m *Manager) updateDevices() error {
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath("/"))
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
if err := obj.Call(objectMgrIface+".GetManagedObjects", 0).Store(&objects); err != nil {
return err
}
devices := []Device{}
paired := []Device{}
connected := []Device{}
for path, interfaces := range objects {
devProps, ok := interfaces[device1Iface]
if !ok {
continue
}
if !strings.HasPrefix(string(path), string(m.adapterPath)+"/") {
continue
}
dev := m.deviceFromProps(string(path), devProps)
devices = append(devices, dev)
if dev.Paired {
paired = append(paired, dev)
}
if dev.Connected {
connected = append(connected, dev)
}
}
m.stateMutex.Lock()
m.state.Devices = devices
m.state.PairedDevices = paired
m.state.ConnectedDevices = connected
m.stateMutex.Unlock()
return nil
}
func (m *Manager) deviceFromProps(path string, props map[string]dbus.Variant) Device {
dev := Device{Path: path}
if v, ok := props["Address"]; ok {
if addr, ok := v.Value().(string); ok {
dev.Address = addr
}
}
if v, ok := props["Name"]; ok {
if name, ok := v.Value().(string); ok {
dev.Name = name
}
}
if v, ok := props["Alias"]; ok {
if alias, ok := v.Value().(string); ok {
dev.Alias = alias
}
}
if v, ok := props["Paired"]; ok {
if paired, ok := v.Value().(bool); ok {
dev.Paired = paired
}
}
if v, ok := props["Trusted"]; ok {
if trusted, ok := v.Value().(bool); ok {
dev.Trusted = trusted
}
}
if v, ok := props["Blocked"]; ok {
if blocked, ok := v.Value().(bool); ok {
dev.Blocked = blocked
}
}
if v, ok := props["Connected"]; ok {
if connected, ok := v.Value().(bool); ok {
dev.Connected = connected
}
}
if v, ok := props["Class"]; ok {
if class, ok := v.Value().(uint32); ok {
dev.Class = class
}
}
if v, ok := props["Icon"]; ok {
if icon, ok := v.Value().(string); ok {
dev.Icon = icon
}
}
if v, ok := props["RSSI"]; ok {
if rssi, ok := v.Value().(int16); ok {
dev.RSSI = rssi
}
}
if v, ok := props["LegacyPairing"]; ok {
if legacy, ok := v.Value().(bool); ok {
dev.LegacyPairing = legacy
}
}
return dev
}
func (m *Manager) startAgent() error {
if m.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
agent, err := NewBluezAgent(m.promptBroker)
if err != nil {
return err
}
m.agent = agent
return nil
}
func (m *Manager) startSignalPump() error {
m.dbusConn.Signal(m.signals)
if err := m.dbusConn.AddMatchSignal(
dbus.WithMatchInterface(propertiesIface),
dbus.WithMatchMember("PropertiesChanged"),
); err != nil {
return err
}
if err := m.dbusConn.AddMatchSignal(
dbus.WithMatchInterface(objectMgrIface),
dbus.WithMatchMember("InterfacesAdded"),
); err != nil {
return err
}
if err := m.dbusConn.AddMatchSignal(
dbus.WithMatchInterface(objectMgrIface),
dbus.WithMatchMember("InterfacesRemoved"),
); err != nil {
return err
}
m.sigWG.Add(1)
go func() {
defer m.sigWG.Done()
for {
select {
case <-m.stopChan:
return
case sig, ok := <-m.signals:
if !ok {
return
}
if sig == nil {
continue
}
m.handleSignal(sig)
}
}
}()
return nil
}
func (m *Manager) handleSignal(sig *dbus.Signal) {
switch sig.Name {
case propertiesIface + ".PropertiesChanged":
if len(sig.Body) < 2 {
return
}
iface, ok := sig.Body[0].(string)
if !ok {
return
}
changed, ok := sig.Body[1].(map[string]dbus.Variant)
if !ok {
return
}
switch iface {
case adapter1Iface:
if strings.HasPrefix(string(sig.Path), string(m.adapterPath)) {
m.handleAdapterPropertiesChanged(changed)
}
case device1Iface:
m.handleDevicePropertiesChanged(sig.Path, changed)
}
case objectMgrIface + ".InterfacesAdded":
m.notifySubscribers()
case objectMgrIface + ".InterfacesRemoved":
m.notifySubscribers()
}
}
func (m *Manager) handleAdapterPropertiesChanged(changed map[string]dbus.Variant) {
m.stateMutex.Lock()
dirty := false
if v, ok := changed["Powered"]; ok {
if powered, ok := v.Value().(bool); ok {
m.state.Powered = powered
dirty = true
}
}
if v, ok := changed["Discovering"]; ok {
if discovering, ok := v.Value().(bool); ok {
m.state.Discovering = discovering
dirty = true
}
}
m.stateMutex.Unlock()
if dirty {
m.notifySubscribers()
}
}
func (m *Manager) handleDevicePropertiesChanged(path dbus.ObjectPath, changed map[string]dbus.Variant) {
pairedVar, hasPaired := changed["Paired"]
_, hasConnected := changed["Connected"]
_, hasTrusted := changed["Trusted"]
if hasPaired {
if paired, ok := pairedVar.Value().(bool); ok && paired {
devicePath := string(path)
m.pendingPairingsMux.Lock()
wasPending := m.pendingPairings[devicePath]
if wasPending {
delete(m.pendingPairings, devicePath)
}
m.pendingPairingsMux.Unlock()
if wasPending {
select {
case m.eventQueue <- func() {
time.Sleep(300 * time.Millisecond)
log.Infof("[Bluetooth] Auto-connecting newly paired device: %s", devicePath)
if err := m.ConnectDevice(devicePath); err != nil {
log.Warnf("[Bluetooth] Auto-connect failed: %v", err)
}
}:
default:
}
}
}
}
if hasPaired || hasConnected || hasTrusted {
select {
case m.eventQueue <- func() {
time.Sleep(100 * time.Millisecond)
m.updateDevices()
m.notifySubscribers()
}:
default:
}
}
}
func (m *Manager) eventWorker() {
defer m.eventWg.Done()
for {
select {
case <-m.stopChan:
return
case event := <-m.eventQueue:
event()
}
}
}
func (m *Manager) notifier() {
defer m.notifierWg.Done()
const minGap = 200 * time.Millisecond
timer := time.NewTimer(minGap)
timer.Stop()
var pending bool
for {
select {
case <-m.stopChan:
timer.Stop()
return
case <-m.dirty:
if pending {
continue
}
pending = true
timer.Reset(minGap)
case <-timer.C:
if !pending {
continue
}
m.updateDevices()
m.subMutex.RLock()
if len(m.subscribers) == 0 {
m.subMutex.RUnlock()
pending = false
continue
}
currentState := m.snapshotState()
if m.lastNotifiedState != nil && !stateChanged(m.lastNotifiedState, &currentState) {
m.subMutex.RUnlock()
pending = false
continue
}
for _, ch := range m.subscribers {
select {
case ch <- currentState:
default:
}
}
m.subMutex.RUnlock()
stateCopy := currentState
m.lastNotifiedState = &stateCopy
pending = false
}
}
}
func (m *Manager) notifySubscribers() {
select {
case m.dirty <- struct{}{}:
default:
}
}
func (m *Manager) GetState() BluetoothState {
return m.snapshotState()
}
func (m *Manager) snapshotState() BluetoothState {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
s := *m.state
s.Devices = append([]Device(nil), m.state.Devices...)
s.PairedDevices = append([]Device(nil), m.state.PairedDevices...)
s.ConnectedDevices = append([]Device(nil), m.state.ConnectedDevices...)
return s
}
func (m *Manager) Subscribe(id string) chan BluetoothState {
ch := make(chan BluetoothState, 64)
m.subMutex.Lock()
m.subscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) SubscribePairing(id string) chan PairingPrompt {
ch := make(chan PairingPrompt, 16)
m.pairingSubMutex.Lock()
m.pairingSubscribers[id] = ch
m.pairingSubMutex.Unlock()
return ch
}
func (m *Manager) UnsubscribePairing(id string) {
m.pairingSubMutex.Lock()
if ch, ok := m.pairingSubscribers[id]; ok {
close(ch)
delete(m.pairingSubscribers, id)
}
m.pairingSubMutex.Unlock()
}
func (m *Manager) broadcastPairingPrompt(prompt PairingPrompt) {
m.pairingSubMutex.RLock()
defer m.pairingSubMutex.RUnlock()
for _, ch := range m.pairingSubscribers {
select {
case ch <- prompt:
default:
}
}
}
func (m *Manager) SubmitPairing(token string, secrets map[string]string, accept bool) error {
if m.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
return m.promptBroker.Resolve(token, PromptReply{
Secrets: secrets,
Accept: accept,
Cancel: false,
})
}
func (m *Manager) CancelPairing(token string) error {
if m.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
return m.promptBroker.Resolve(token, PromptReply{
Cancel: true,
})
}
func (m *Manager) StartDiscovery() error {
obj := m.dbusConn.Object(bluezService, m.adapterPath)
return obj.Call(adapter1Iface+".StartDiscovery", 0).Err
}
func (m *Manager) StopDiscovery() error {
obj := m.dbusConn.Object(bluezService, m.adapterPath)
return obj.Call(adapter1Iface+".StopDiscovery", 0).Err
}
func (m *Manager) SetPowered(powered bool) error {
obj := m.dbusConn.Object(bluezService, m.adapterPath)
return obj.Call(propertiesIface+".Set", 0, adapter1Iface, "Powered", dbus.MakeVariant(powered)).Err
}
func (m *Manager) PairDevice(devicePath string) error {
m.pendingPairingsMux.Lock()
m.pendingPairings[devicePath] = true
m.pendingPairingsMux.Unlock()
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath(devicePath))
err := obj.Call(device1Iface+".Pair", 0).Err
if err != nil {
m.pendingPairingsMux.Lock()
delete(m.pendingPairings, devicePath)
m.pendingPairingsMux.Unlock()
}
return err
}
func (m *Manager) ConnectDevice(devicePath string) error {
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath(devicePath))
return obj.Call(device1Iface+".Connect", 0).Err
}
func (m *Manager) DisconnectDevice(devicePath string) error {
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath(devicePath))
return obj.Call(device1Iface+".Disconnect", 0).Err
}
func (m *Manager) RemoveDevice(devicePath string) error {
obj := m.dbusConn.Object(bluezService, m.adapterPath)
return obj.Call(adapter1Iface+".RemoveDevice", 0, dbus.ObjectPath(devicePath)).Err
}
func (m *Manager) TrustDevice(devicePath string, trusted bool) error {
obj := m.dbusConn.Object(bluezService, dbus.ObjectPath(devicePath))
return obj.Call(propertiesIface+".Set", 0, device1Iface, "Trusted", dbus.MakeVariant(trusted)).Err
}
func (m *Manager) Close() {
close(m.stopChan)
m.notifierWg.Wait()
m.eventWg.Wait()
m.sigWG.Wait()
if m.signals != nil {
m.dbusConn.RemoveSignal(m.signals)
close(m.signals)
}
if m.agent != nil {
m.agent.Close()
}
m.subMutex.Lock()
for _, ch := range m.subscribers {
close(ch)
}
m.subscribers = make(map[string]chan BluetoothState)
m.subMutex.Unlock()
m.pairingSubMutex.Lock()
for _, ch := range m.pairingSubscribers {
close(ch)
}
m.pairingSubscribers = make(map[string]chan PairingPrompt)
m.pairingSubMutex.Unlock()
if m.dbusConn != nil {
m.dbusConn.Close()
}
}
func stateChanged(old, new *BluetoothState) bool {
if old.Powered != new.Powered {
return true
}
if old.Discovering != new.Discovering {
return true
}
if len(old.Devices) != len(new.Devices) {
return true
}
if len(old.PairedDevices) != len(new.PairedDevices) {
return true
}
if len(old.ConnectedDevices) != len(new.ConnectedDevices) {
return true
}
for i := range old.Devices {
if old.Devices[i].Path != new.Devices[i].Path {
return true
}
if old.Devices[i].Paired != new.Devices[i].Paired {
return true
}
if old.Devices[i].Connected != new.Devices[i].Connected {
return true
}
}
return false
}

View File

@@ -0,0 +1,99 @@
package bluez
import (
"context"
"fmt"
"sync"
"github.com/AvengeMedia/danklinux/internal/errdefs"
)
type SubscriptionBroker struct {
mu sync.RWMutex
pending map[string]chan PromptReply
requests map[string]PromptRequest
broadcastPrompt func(PairingPrompt)
}
func NewSubscriptionBroker(broadcastPrompt func(PairingPrompt)) PromptBroker {
return &SubscriptionBroker{
pending: make(map[string]chan PromptReply),
requests: make(map[string]PromptRequest),
broadcastPrompt: broadcastPrompt,
}
}
func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string, error) {
token, err := generateToken()
if err != nil {
return "", err
}
replyChan := make(chan PromptReply, 1)
b.mu.Lock()
b.pending[token] = replyChan
b.requests[token] = req
b.mu.Unlock()
if b.broadcastPrompt != nil {
prompt := PairingPrompt{
Token: token,
DevicePath: req.DevicePath,
DeviceName: req.DeviceName,
DeviceAddr: req.DeviceAddr,
RequestType: req.RequestType,
Fields: req.Fields,
Hints: req.Hints,
Passkey: req.Passkey,
}
b.broadcastPrompt(prompt)
}
return token, nil
}
func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptReply, error) {
b.mu.RLock()
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists {
return PromptReply{}, fmt.Errorf("unknown token: %s", token)
}
select {
case <-ctx.Done():
b.cleanup(token)
return PromptReply{}, errdefs.ErrSecretPromptTimeout
case reply := <-replyChan:
b.cleanup(token)
if reply.Cancel {
return reply, errdefs.ErrSecretPromptCancelled
}
return reply, nil
}
}
func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error {
b.mu.RLock()
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists {
return fmt.Errorf("unknown or expired token: %s", token)
}
select {
case replyChan <- reply:
return nil
default:
return fmt.Errorf("failed to deliver reply for token: %s", token)
}
}
func (b *SubscriptionBroker) cleanup(token string) {
b.mu.Lock()
delete(b.pending, token)
delete(b.requests, token)
b.mu.Unlock()
}

View File

@@ -0,0 +1,80 @@
package bluez
import (
"sync"
"github.com/godbus/dbus/v5"
)
type BluetoothState struct {
Powered bool `json:"powered"`
Discovering bool `json:"discovering"`
Devices []Device `json:"devices"`
PairedDevices []Device `json:"pairedDevices"`
ConnectedDevices []Device `json:"connectedDevices"`
}
type Device struct {
Path string `json:"path"`
Address string `json:"address"`
Name string `json:"name"`
Alias string `json:"alias"`
Paired bool `json:"paired"`
Trusted bool `json:"trusted"`
Blocked bool `json:"blocked"`
Connected bool `json:"connected"`
Class uint32 `json:"class"`
Icon string `json:"icon"`
RSSI int16 `json:"rssi"`
LegacyPairing bool `json:"legacyPairing"`
}
type PromptRequest struct {
DevicePath string `json:"devicePath"`
DeviceName string `json:"deviceName"`
DeviceAddr string `json:"deviceAddr"`
RequestType string `json:"requestType"`
Fields []string `json:"fields"`
Hints []string `json:"hints"`
Passkey *uint32 `json:"passkey,omitempty"`
}
type PromptReply struct {
Secrets map[string]string `json:"secrets"`
Accept bool `json:"accept"`
Cancel bool `json:"cancel"`
}
type PairingPrompt struct {
Token string `json:"token"`
DevicePath string `json:"devicePath"`
DeviceName string `json:"deviceName"`
DeviceAddr string `json:"deviceAddr"`
RequestType string `json:"requestType"`
Fields []string `json:"fields"`
Hints []string `json:"hints"`
Passkey *uint32 `json:"passkey,omitempty"`
}
type Manager struct {
state *BluetoothState
stateMutex sync.RWMutex
subscribers map[string]chan BluetoothState
subMutex sync.RWMutex
stopChan chan struct{}
dbusConn *dbus.Conn
signals chan *dbus.Signal
sigWG sync.WaitGroup
agent *BluezAgent
promptBroker PromptBroker
pairingSubscribers map[string]chan PairingPrompt
pairingSubMutex sync.RWMutex
dirty chan struct{}
notifierWg sync.WaitGroup
lastNotifiedState *BluetoothState
adapterPath dbus.ObjectPath
pendingPairings map[string]bool
pendingPairingsMux sync.Mutex
eventQueue chan func()
eventWg sync.WaitGroup
}

View File

@@ -0,0 +1,210 @@
package bluez
import (
"encoding/json"
"testing"
)
func TestBluetoothStateJSON(t *testing.T) {
state := BluetoothState{
Powered: true,
Discovering: false,
Devices: []Device{
{
Path: "/org/bluez/hci0/dev_AA_BB_CC_DD_EE_FF",
Address: "AA:BB:CC:DD:EE:FF",
Name: "TestDevice",
Alias: "My Device",
Paired: true,
Trusted: false,
Connected: true,
Class: 0x240418,
Icon: "audio-headset",
RSSI: -50,
},
},
PairedDevices: []Device{
{
Path: "/org/bluez/hci0/dev_AA_BB_CC_DD_EE_FF",
Address: "AA:BB:CC:DD:EE:FF",
Paired: true,
},
},
ConnectedDevices: []Device{
{
Path: "/org/bluez/hci0/dev_AA_BB_CC_DD_EE_FF",
Address: "AA:BB:CC:DD:EE:FF",
Connected: true,
},
},
}
data, err := json.Marshal(state)
if err != nil {
t.Fatalf("failed to marshal state: %v", err)
}
var decoded BluetoothState
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal state: %v", err)
}
if decoded.Powered != state.Powered {
t.Errorf("expected Powered=%v, got %v", state.Powered, decoded.Powered)
}
if len(decoded.Devices) != 1 {
t.Fatalf("expected 1 device, got %d", len(decoded.Devices))
}
if decoded.Devices[0].Address != "AA:BB:CC:DD:EE:FF" {
t.Errorf("expected address AA:BB:CC:DD:EE:FF, got %s", decoded.Devices[0].Address)
}
}
func TestDeviceJSON(t *testing.T) {
device := Device{
Path: "/org/bluez/hci0/dev_AA_BB_CC_DD_EE_FF",
Address: "AA:BB:CC:DD:EE:FF",
Name: "TestDevice",
Alias: "My Device",
Paired: true,
Trusted: true,
Blocked: false,
Connected: true,
Class: 0x240418,
Icon: "audio-headset",
RSSI: -50,
LegacyPairing: false,
}
data, err := json.Marshal(device)
if err != nil {
t.Fatalf("failed to marshal device: %v", err)
}
var decoded Device
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal device: %v", err)
}
if decoded.Address != device.Address {
t.Errorf("expected Address=%s, got %s", device.Address, decoded.Address)
}
if decoded.Name != device.Name {
t.Errorf("expected Name=%s, got %s", device.Name, decoded.Name)
}
if decoded.Paired != device.Paired {
t.Errorf("expected Paired=%v, got %v", device.Paired, decoded.Paired)
}
if decoded.RSSI != device.RSSI {
t.Errorf("expected RSSI=%d, got %d", device.RSSI, decoded.RSSI)
}
}
func TestPairingPromptJSON(t *testing.T) {
passkey := uint32(123456)
prompt := PairingPrompt{
Token: "test-token",
DevicePath: "/org/bluez/hci0/dev_AA_BB_CC_DD_EE_FF",
DeviceName: "TestDevice",
DeviceAddr: "AA:BB:CC:DD:EE:FF",
RequestType: "confirm",
Fields: []string{"decision"},
Hints: []string{},
Passkey: &passkey,
}
data, err := json.Marshal(prompt)
if err != nil {
t.Fatalf("failed to marshal prompt: %v", err)
}
var decoded PairingPrompt
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal prompt: %v", err)
}
if decoded.Token != prompt.Token {
t.Errorf("expected Token=%s, got %s", prompt.Token, decoded.Token)
}
if decoded.DeviceName != prompt.DeviceName {
t.Errorf("expected DeviceName=%s, got %s", prompt.DeviceName, decoded.DeviceName)
}
if decoded.Passkey == nil {
t.Fatal("expected non-nil Passkey")
}
if *decoded.Passkey != *prompt.Passkey {
t.Errorf("expected Passkey=%d, got %d", *prompt.Passkey, *decoded.Passkey)
}
}
func TestPromptReplyJSON(t *testing.T) {
reply := PromptReply{
Secrets: map[string]string{
"pin": "1234",
"passkey": "567890",
},
Accept: true,
Cancel: false,
}
data, err := json.Marshal(reply)
if err != nil {
t.Fatalf("failed to marshal reply: %v", err)
}
var decoded PromptReply
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal reply: %v", err)
}
if decoded.Secrets["pin"] != reply.Secrets["pin"] {
t.Errorf("expected pin=%s, got %s", reply.Secrets["pin"], decoded.Secrets["pin"])
}
if decoded.Accept != reply.Accept {
t.Errorf("expected Accept=%v, got %v", reply.Accept, decoded.Accept)
}
}
func TestPromptRequestJSON(t *testing.T) {
passkey := uint32(123456)
req := PromptRequest{
DevicePath: "/org/bluez/hci0/dev_AA_BB_CC_DD_EE_FF",
DeviceName: "TestDevice",
DeviceAddr: "AA:BB:CC:DD:EE:FF",
RequestType: "confirm",
Fields: []string{"decision"},
Hints: []string{"hint1", "hint2"},
Passkey: &passkey,
}
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
var decoded PromptRequest
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("failed to unmarshal request: %v", err)
}
if decoded.DevicePath != req.DevicePath {
t.Errorf("expected DevicePath=%s, got %s", req.DevicePath, decoded.DevicePath)
}
if decoded.RequestType != req.RequestType {
t.Errorf("expected RequestType=%s, got %s", req.RequestType, decoded.RequestType)
}
if len(decoded.Fields) != len(req.Fields) {
t.Errorf("expected %d fields, got %d", len(req.Fields), len(decoded.Fields))
}
}

View File

@@ -0,0 +1,485 @@
package brightness
import (
"encoding/binary"
"fmt"
"math"
"os"
"path/filepath"
"strings"
"syscall"
"time"
"unsafe"
"github.com/AvengeMedia/danklinux/internal/log"
"golang.org/x/sys/unix"
)
const (
I2C_SLAVE = 0x0703
DDCCI_ADDR = 0x37
DDCCI_VCP_GET = 0x01
DDCCI_VCP_SET = 0x03
VCP_BRIGHTNESS = 0x10
DDC_SOURCE_ADDR = 0x51
)
func NewDDCBackend() (*DDCBackend, error) {
b := &DDCBackend{
devices: make(map[string]*ddcDevice),
scanInterval: 30 * time.Second,
debounceTimers: make(map[string]*time.Timer),
debouncePending: make(map[string]ddcPendingSet),
}
if err := b.scanI2CDevices(); err != nil {
return nil, err
}
return b, nil
}
func (b *DDCBackend) scanI2CDevices() error {
b.scanMutex.Lock()
lastScan := b.lastScan
b.scanMutex.Unlock()
if time.Since(lastScan) < b.scanInterval {
return nil
}
b.scanMutex.Lock()
defer b.scanMutex.Unlock()
if time.Since(b.lastScan) < b.scanInterval {
return nil
}
b.devicesMutex.Lock()
defer b.devicesMutex.Unlock()
b.devices = make(map[string]*ddcDevice)
for i := 0; i < 32; i++ {
busPath := fmt.Sprintf("/dev/i2c-%d", i)
if _, err := os.Stat(busPath); os.IsNotExist(err) {
continue
}
// Skip SMBus, GPU internal buses (e.g. AMDGPU SMU) to prevent GPU hangs
if isIgnorableI2CBus(i) {
log.Debugf("Skipping ignorable i2c-%d", i)
continue
}
dev, err := b.probeDDCDevice(i)
if err != nil || dev == nil {
continue
}
id := fmt.Sprintf("ddc:i2c-%d", i)
dev.id = id
b.devices[id] = dev
log.Debugf("found DDC device on i2c-%d", i)
}
b.lastScan = time.Now()
return nil
}
func (b *DDCBackend) probeDDCDevice(bus int) (*ddcDevice, error) {
busPath := fmt.Sprintf("/dev/i2c-%d", bus)
fd, err := syscall.Open(busPath, syscall.O_RDWR, 0)
if err != nil {
return nil, err
}
defer syscall.Close(fd)
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), I2C_SLAVE, uintptr(DDCCI_ADDR)); errno != 0 {
return nil, errno
}
dummy := make([]byte, 32)
syscall.Read(fd, dummy)
writebuf := []byte{0x00}
n, err := syscall.Write(fd, writebuf)
if err == nil && n == len(writebuf) {
name := b.getDDCName(bus)
dev := &ddcDevice{
bus: bus,
addr: DDCCI_ADDR,
name: name,
}
b.readInitialBrightness(fd, dev)
return dev, nil
}
readbuf := make([]byte, 4)
n, err = syscall.Read(fd, readbuf)
if err != nil || n == 0 {
return nil, fmt.Errorf("x37 unresponsive")
}
name := b.getDDCName(bus)
dev := &ddcDevice{
bus: bus,
addr: DDCCI_ADDR,
name: name,
}
b.readInitialBrightness(fd, dev)
return dev, nil
}
func (b *DDCBackend) getDDCName(bus int) string {
sysfsPath := fmt.Sprintf("/sys/class/i2c-adapter/i2c-%d/name", bus)
data, err := os.ReadFile(sysfsPath)
if err != nil {
return fmt.Sprintf("I2C-%d", bus)
}
name := strings.TrimSpace(string(data))
if name == "" {
name = fmt.Sprintf("I2C-%d", bus)
}
return name
}
func (b *DDCBackend) readInitialBrightness(fd int, dev *ddcDevice) {
cap, err := b.getVCPFeature(fd, VCP_BRIGHTNESS)
if err != nil {
log.Debugf("failed to read initial brightness for %s: %v", dev.name, err)
return
}
dev.max = cap.max
dev.lastBrightness = cap.current
log.Debugf("initialized %s with brightness %d/%d", dev.name, cap.current, cap.max)
}
func (b *DDCBackend) GetDevices() ([]Device, error) {
if err := b.scanI2CDevices(); err != nil {
log.Debugf("DDC scan error: %v", err)
}
b.devicesMutex.Lock()
defer b.devicesMutex.Unlock()
devices := make([]Device, 0, len(b.devices))
for id, dev := range b.devices {
devices = append(devices, Device{
Class: ClassDDC,
ID: id,
Name: dev.name,
Current: dev.lastBrightness,
Max: dev.max,
CurrentPercent: dev.lastBrightness,
Backend: "ddc",
})
}
return devices, nil
}
func (b *DDCBackend) SetBrightness(id string, value int, exponential bool, callback func()) error {
return b.SetBrightnessWithExponent(id, value, exponential, 1.2, callback)
}
func (b *DDCBackend) SetBrightnessWithExponent(id string, value int, exponential bool, exponent float64, callback func()) error {
b.devicesMutex.RLock()
_, ok := b.devices[id]
b.devicesMutex.RUnlock()
if !ok {
return fmt.Errorf("device not found: %s", id)
}
if value < 0 {
return fmt.Errorf("value out of range: %d", value)
}
b.debounceMutex.Lock()
defer b.debounceMutex.Unlock()
b.debouncePending[id] = ddcPendingSet{
percent: value,
callback: callback,
}
if timer, exists := b.debounceTimers[id]; exists {
timer.Reset(200 * time.Millisecond)
} else {
b.debounceTimers[id] = time.AfterFunc(200*time.Millisecond, func() {
b.debounceMutex.Lock()
pending, exists := b.debouncePending[id]
if exists {
delete(b.debouncePending, id)
}
b.debounceMutex.Unlock()
if !exists {
return
}
err := b.setBrightnessImmediateWithExponent(id, pending.percent, exponential, exponent)
if err != nil {
log.Debugf("Failed to set brightness for %s: %v", id, err)
}
if pending.callback != nil {
pending.callback()
}
})
}
return nil
}
func (b *DDCBackend) setBrightnessImmediate(id string, value int, exponential bool) error {
return b.setBrightnessImmediateWithExponent(id, value, exponential, 1.2)
}
func (b *DDCBackend) setBrightnessImmediateWithExponent(id string, value int, exponential bool, exponent float64) error {
b.devicesMutex.RLock()
dev, ok := b.devices[id]
b.devicesMutex.RUnlock()
if !ok {
return fmt.Errorf("device not found: %s", id)
}
busPath := fmt.Sprintf("/dev/i2c-%d", dev.bus)
fd, err := syscall.Open(busPath, syscall.O_RDWR, 0)
if err != nil {
return fmt.Errorf("open i2c device: %w", err)
}
defer syscall.Close(fd)
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), I2C_SLAVE, uintptr(dev.addr)); errno != 0 {
return fmt.Errorf("set i2c slave addr: %w", errno)
}
max := dev.max
if max == 0 {
cap, err := b.getVCPFeature(fd, VCP_BRIGHTNESS)
if err != nil {
return fmt.Errorf("get current capability: %w", err)
}
max = cap.max
b.devicesMutex.Lock()
dev.max = max
b.devicesMutex.Unlock()
}
if err := b.setVCPFeature(fd, VCP_BRIGHTNESS, value); err != nil {
return fmt.Errorf("set vcp feature: %w", err)
}
log.Debugf("set %s to %d/%d", id, value, max)
b.devicesMutex.Lock()
dev.max = max
dev.lastBrightness = value
b.devicesMutex.Unlock()
return nil
}
func (b *DDCBackend) getVCPFeature(fd int, vcp byte) (*ddcCapability, error) {
for flushTry := 0; flushTry < 3; flushTry++ {
dummy := make([]byte, 32)
n, _ := syscall.Read(fd, dummy)
if n == 0 {
break
}
time.Sleep(20 * time.Millisecond)
}
data := []byte{
DDCCI_VCP_GET,
vcp,
}
payload := []byte{
DDC_SOURCE_ADDR,
byte(len(data)) | 0x80,
}
payload = append(payload, data...)
payload = append(payload, ddcciChecksum(payload))
n, err := syscall.Write(fd, payload)
if err != nil || n != len(payload) {
return nil, fmt.Errorf("write i2c: %w", err)
}
time.Sleep(50 * time.Millisecond)
pollFds := []unix.PollFd{
{
Fd: int32(fd),
Events: unix.POLLIN,
},
}
pollTimeout := 200
pollResult, err := unix.Poll(pollFds, pollTimeout)
if err != nil {
return nil, fmt.Errorf("poll i2c: %w", err)
}
if pollResult == 0 {
return nil, fmt.Errorf("poll timeout after %dms", pollTimeout)
}
if pollFds[0].Revents&unix.POLLIN == 0 {
return nil, fmt.Errorf("poll returned but POLLIN not set")
}
response := make([]byte, 12)
n, err = syscall.Read(fd, response)
if err != nil || n < 8 {
return nil, fmt.Errorf("read i2c: %w", err)
}
if response[0] != 0x6E || response[2] != 0x02 {
return nil, fmt.Errorf("invalid ddc response")
}
resultCode := response[3]
if resultCode != 0x00 {
return nil, fmt.Errorf("vcp feature not supported")
}
responseVCP := response[4]
if responseVCP != vcp {
return nil, fmt.Errorf("vcp mismatch: wanted 0x%02x, got 0x%02x", vcp, responseVCP)
}
maxHigh := response[6]
maxLow := response[7]
currentHigh := response[8]
currentLow := response[9]
max := int(binary.BigEndian.Uint16([]byte{maxHigh, maxLow}))
current := int(binary.BigEndian.Uint16([]byte{currentHigh, currentLow}))
return &ddcCapability{
vcp: vcp,
max: max,
current: current,
}, nil
}
func ddcciChecksum(payload []byte) byte {
sum := byte(0x6E)
for _, b := range payload {
sum ^= b
}
return sum
}
func (b *DDCBackend) setVCPFeature(fd int, vcp byte, value int) error {
data := []byte{
DDCCI_VCP_SET,
vcp,
byte(value >> 8),
byte(value & 0xFF),
}
payload := []byte{
DDC_SOURCE_ADDR,
byte(len(data)) | 0x80,
}
payload = append(payload, data...)
payload = append(payload, ddcciChecksum(payload))
if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), I2C_SLAVE, uintptr(DDCCI_ADDR)); errno != 0 {
return fmt.Errorf("set i2c slave for write: %w", errno)
}
n, err := syscall.Write(fd, payload)
if err != nil || n != len(payload) {
return fmt.Errorf("write i2c: wrote %d/%d: %w", n, len(payload), err)
}
time.Sleep(50 * time.Millisecond)
return nil
}
func (b *DDCBackend) percentToValue(percent int, max int, exponential bool) int {
const minValue = 1
if percent == 0 {
return minValue
}
usableRange := max - minValue
var value int
if exponential {
const exponent = 2.0
normalizedPercent := float64(percent) / 100.0
hardwarePercent := math.Pow(normalizedPercent, 1.0/exponent)
value = minValue + int(math.Round(hardwarePercent*float64(usableRange)))
} else {
value = minValue + ((percent - 1) * usableRange / 99)
}
if value < minValue {
value = minValue
}
if value > max {
value = max
}
return value
}
func (b *DDCBackend) valueToPercent(value int, max int, exponential bool) int {
const minValue = 1
if max == 0 {
return 0
}
if value <= minValue {
return 1
}
usableRange := max - minValue
if usableRange == 0 {
return 100
}
var percent int
if exponential {
const exponent = 2.0
linearPercent := 1 + ((value - minValue) * 99 / usableRange)
normalizedLinear := float64(linearPercent) / 100.0
expPercent := math.Pow(normalizedLinear, exponent)
percent = int(math.Round(expPercent * 100.0))
} else {
percent = 1 + ((value - minValue) * 99 / usableRange)
}
if percent > 100 {
percent = 100
}
if percent < 1 {
percent = 1
}
return percent
}
func (b *DDCBackend) Close() {
}
var _ = unsafe.Sizeof(0)
var _ = filepath.Join

View File

@@ -0,0 +1,135 @@
package brightness
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/AvengeMedia/danklinux/internal/log"
)
// isIgnorableI2CBus checks if an I2C bus should be skipped during DDC probing.
// Based on ddcutil's sysfs_is_ignorable_i2c_device() (sysfs_base.c:1441)
func isIgnorableI2CBus(busno int) bool {
name := getI2CDeviceSysfsName(busno)
driver := getI2CSysfsDriver(busno)
if name != "" && isIgnorableI2CDeviceName(name, driver) {
log.Debugf("i2c-%d: ignoring '%s' (driver: %s)", busno, name, driver)
return true
}
// Only probe display adapters (0x03xxxx) and docking stations (0x0axxxx)
class := getI2CDeviceSysfsClass(busno)
if class != 0 {
classHigh := class & 0xFFFF0000
ignorable := (classHigh != 0x030000 && classHigh != 0x0A0000)
if ignorable {
log.Debugf("i2c-%d: ignoring class 0x%08x", busno, class)
}
return ignorable
}
return false
}
// Based on ddcutil's ignorable_i2c_device_sysfs_name() (sysfs_base.c:1408)
func isIgnorableI2CDeviceName(name, driver string) bool {
ignorablePrefixes := []string{
"SMBus",
"Synopsys DesignWare",
"soc:i2cdsi",
"smu",
"mac-io",
"u4",
"AMDGPU SMU", // AMD Navi2+ - probing hangs GPU
}
for _, prefix := range ignorablePrefixes {
if strings.HasPrefix(name, prefix) {
return true
}
}
// nouveau driver: only nvkm-* buses are valid
if driver == "nouveau" && !strings.HasPrefix(name, "nvkm-") {
return true
}
return false
}
// Based on ddcutil's get_i2c_device_sysfs_name() (sysfs_base.c:1175)
func getI2CDeviceSysfsName(busno int) string {
path := fmt.Sprintf("/sys/bus/i2c/devices/i2c-%d/name", busno)
data, err := os.ReadFile(path)
if err != nil {
return ""
}
return strings.TrimSpace(string(data))
}
// Based on ddcutil's get_i2c_device_sysfs_class() (sysfs_base.c:1380)
func getI2CDeviceSysfsClass(busno int) uint32 {
classPath := fmt.Sprintf("/sys/bus/i2c/devices/i2c-%d/device/class", busno)
data, err := os.ReadFile(classPath)
if err != nil {
classPath = fmt.Sprintf("/sys/bus/i2c/devices/i2c-%d/device/device/device/class", busno)
data, err = os.ReadFile(classPath)
if err != nil {
return 0
}
}
classStr := strings.TrimSpace(string(data))
classStr = strings.TrimPrefix(classStr, "0x")
class, err := strconv.ParseUint(classStr, 16, 32)
if err != nil {
return 0
}
return uint32(class)
}
// Based on ddcutil's get_i2c_sysfs_driver_by_busno() (sysfs_base.c:1284)
func getI2CSysfsDriver(busno int) string {
devicePath := fmt.Sprintf("/sys/bus/i2c/devices/i2c-%d", busno)
adapterPath, err := findI2CAdapter(devicePath)
if err != nil {
return ""
}
driverLink := filepath.Join(adapterPath, "driver")
target, err := os.Readlink(driverLink)
if err != nil {
return ""
}
return filepath.Base(target)
}
func findI2CAdapter(devicePath string) (string, error) {
currentPath := devicePath
for depth := 0; depth < 10; depth++ {
if _, err := os.Stat(filepath.Join(currentPath, "name")); err == nil {
return currentPath, nil
}
deviceLink := filepath.Join(currentPath, "device")
target, err := os.Readlink(deviceLink)
if err != nil {
break
}
if !filepath.IsAbs(target) {
target = filepath.Join(filepath.Dir(currentPath), target)
}
currentPath = filepath.Clean(target)
}
return "", fmt.Errorf("could not find adapter for %s", devicePath)
}

View File

@@ -0,0 +1,122 @@
package brightness
import (
"testing"
)
func TestIsIgnorableI2CDeviceName(t *testing.T) {
tests := []struct {
name string
deviceName string
driver string
want bool
}{
{
name: "AMDGPU SMU should be ignored",
deviceName: "AMDGPU SMU",
driver: "amdgpu",
want: true,
},
{
name: "SMBus should be ignored",
deviceName: "SMBus I801 adapter",
driver: "",
want: true,
},
{
name: "Synopsys DesignWare should be ignored",
deviceName: "Synopsys DesignWare I2C adapter",
driver: "",
want: true,
},
{
name: "smu prefix should be ignored (Mac G5)",
deviceName: "smu-i2c-controller",
driver: "",
want: true,
},
{
name: "Regular NVIDIA DDC should not be ignored",
deviceName: "NVIDIA i2c adapter 1",
driver: "nvidia",
want: false,
},
{
name: "nouveau nvkm bus should not be ignored",
deviceName: "nvkm-0000:01:00.0-bus-0000",
driver: "nouveau",
want: false,
},
{
name: "nouveau non-nvkm bus should be ignored",
deviceName: "nouveau-other-bus",
driver: "nouveau",
want: true,
},
{
name: "Regular AMD display adapter should not be ignored",
deviceName: "AMDGPU DM i2c hw bus 0",
driver: "amdgpu",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isIgnorableI2CDeviceName(tt.deviceName, tt.driver)
if got != tt.want {
t.Errorf("isIgnorableI2CDeviceName(%q, %q) = %v, want %v",
tt.deviceName, tt.driver, got, tt.want)
}
})
}
}
func TestClassFiltering(t *testing.T) {
tests := []struct {
name string
class uint32
want bool
}{
{
name: "Display adapter class should not be ignored",
class: 0x030000,
want: false,
},
{
name: "Docking station class should not be ignored",
class: 0x0a0000,
want: false,
},
{
name: "Display adapter with subclass should not be ignored",
class: 0x030001,
want: false,
},
{
name: "SMBus class should be ignored",
class: 0x0c0500,
want: true,
},
{
name: "Bridge class should be ignored",
class: 0x060400,
want: true,
},
{
name: "Generic system peripheral should be ignored",
class: 0x088000,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
classHigh := tt.class & 0xFFFF0000
ignorable := (classHigh != 0x030000 && classHigh != 0x0A0000)
if ignorable != tt.want {
t.Errorf("class 0x%08x: ignorable = %v, want %v", tt.class, ignorable, tt.want)
}
})
}
}

View File

@@ -0,0 +1,135 @@
package brightness
import (
"testing"
)
func TestDDCBackend_PercentConversions(t *testing.T) {
tests := []struct {
name string
max int
percent int
wantValue int
}{
{
name: "0% should map to minValue=1",
max: 100,
percent: 0,
wantValue: 1,
},
{
name: "1% should be 1",
max: 100,
percent: 1,
wantValue: 1,
},
{
name: "50% should be ~50",
max: 100,
percent: 50,
wantValue: 50,
},
{
name: "100% should be max",
max: 100,
percent: 100,
wantValue: 100,
},
}
b := &DDCBackend{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := b.percentToValue(tt.percent, tt.max, false)
diff := got - tt.wantValue
if diff < 0 {
diff = -diff
}
if diff > 1 {
t.Errorf("percentToValue() = %v, want %v (±1)", got, tt.wantValue)
}
})
}
}
func TestDDCBackend_ValueToPercent(t *testing.T) {
tests := []struct {
name string
max int
value int
wantPercent int
tolerance int
}{
{
name: "zero value should be 1%",
max: 100,
value: 0,
wantPercent: 1,
tolerance: 0,
},
{
name: "min value should be 1%",
max: 100,
value: 1,
wantPercent: 1,
tolerance: 0,
},
{
name: "mid value should be ~50%",
max: 100,
value: 50,
wantPercent: 50,
tolerance: 2,
},
{
name: "max value should be 100%",
max: 100,
value: 100,
wantPercent: 100,
tolerance: 0,
},
}
b := &DDCBackend{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := b.valueToPercent(tt.value, tt.max, false)
diff := got - tt.wantPercent
if diff < 0 {
diff = -diff
}
if diff > tt.tolerance {
t.Errorf("valueToPercent() = %v, want %v (±%d)", got, tt.wantPercent, tt.tolerance)
}
})
}
}
func TestDDCBackend_RoundTrip(t *testing.T) {
b := &DDCBackend{}
tests := []struct {
name string
max int
percent int
}{
{"1%", 100, 1},
{"25%", 100, 25},
{"50%", 100, 50},
{"75%", 100, 75},
{"100%", 100, 100},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
value := b.percentToValue(tt.percent, tt.max, false)
gotPercent := b.valueToPercent(value, tt.max, false)
if diff := tt.percent - gotPercent; diff < -1 || diff > 1 {
t.Errorf("round trip failed: wanted %d%%, got %d%% (value=%d)", tt.percent, gotPercent, value)
}
})
}
}

View File

@@ -0,0 +1,163 @@
package brightness
import (
"encoding/json"
"net"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
func HandleRequest(conn net.Conn, req Request, m *Manager) {
switch req.Method {
case "brightness.getState":
handleGetState(conn, req, m)
case "brightness.setBrightness":
handleSetBrightness(conn, req, m)
case "brightness.increment":
handleIncrement(conn, req, m)
case "brightness.decrement":
handleDecrement(conn, req, m)
case "brightness.rescan":
handleRescan(conn, req, m)
case "brightness.subscribe":
handleSubscribe(conn, req, m)
default:
models.RespondError(conn, req.ID.(int), "unknown method: "+req.Method)
}
}
func handleGetState(conn net.Conn, req Request, m *Manager) {
state := m.GetState()
models.Respond(conn, req.ID.(int), state)
}
func handleSetBrightness(conn net.Conn, req Request, m *Manager) {
var params SetBrightnessParams
device, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID.(int), "missing or invalid device parameter")
return
}
params.Device = device
percentFloat, ok := req.Params["percent"].(float64)
if !ok {
models.RespondError(conn, req.ID.(int), "missing or invalid percent parameter")
return
}
params.Percent = int(percentFloat)
if exponential, ok := req.Params["exponential"].(bool); ok {
params.Exponential = exponential
}
exponent := 1.2
if exponentFloat, ok := req.Params["exponent"].(float64); ok {
params.Exponent = exponentFloat
exponent = exponentFloat
}
if err := m.SetBrightnessWithExponent(params.Device, params.Percent, params.Exponential, exponent); err != nil {
models.RespondError(conn, req.ID.(int), err.Error())
return
}
state := m.GetState()
models.Respond(conn, req.ID.(int), state)
}
func handleIncrement(conn net.Conn, req Request, m *Manager) {
device, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID.(int), "missing or invalid device parameter")
return
}
step := 10
if stepFloat, ok := req.Params["step"].(float64); ok {
step = int(stepFloat)
}
exponential := false
if expBool, ok := req.Params["exponential"].(bool); ok {
exponential = expBool
}
exponent := 1.2
if exponentFloat, ok := req.Params["exponent"].(float64); ok {
exponent = exponentFloat
}
if err := m.IncrementBrightnessWithExponent(device, step, exponential, exponent); err != nil {
models.RespondError(conn, req.ID.(int), err.Error())
return
}
state := m.GetState()
models.Respond(conn, req.ID.(int), state)
}
func handleDecrement(conn net.Conn, req Request, m *Manager) {
device, ok := req.Params["device"].(string)
if !ok {
models.RespondError(conn, req.ID.(int), "missing or invalid device parameter")
return
}
step := 10
if stepFloat, ok := req.Params["step"].(float64); ok {
step = int(stepFloat)
}
exponential := false
if expBool, ok := req.Params["exponential"].(bool); ok {
exponential = expBool
}
exponent := 1.2
if exponentFloat, ok := req.Params["exponent"].(float64); ok {
exponent = exponentFloat
}
if err := m.IncrementBrightnessWithExponent(device, -step, exponential, exponent); err != nil {
models.RespondError(conn, req.ID.(int), err.Error())
return
}
state := m.GetState()
models.Respond(conn, req.ID.(int), state)
}
func handleRescan(conn net.Conn, req Request, m *Manager) {
m.Rescan()
state := m.GetState()
models.Respond(conn, req.ID.(int), state)
}
func handleSubscribe(conn net.Conn, req Request, m *Manager) {
clientID := "brightness-subscriber"
if idStr, ok := req.ID.(string); ok && idStr != "" {
clientID = idStr
}
ch := m.Subscribe(clientID)
defer m.Unsubscribe(clientID)
initialState := m.GetState()
if err := json.NewEncoder(conn).Encode(models.Response[State]{
ID: req.ID.(int),
Result: &initialState,
}); err != nil {
return
}
for state := range ch {
if err := json.NewEncoder(conn).Encode(models.Response[State]{
ID: req.ID.(int),
Result: &state,
}); err != nil {
return
}
}
}

View File

@@ -0,0 +1,67 @@
package brightness
import (
"fmt"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/godbus/dbus/v5"
)
type DBusConn interface {
Object(dest string, path dbus.ObjectPath) dbus.BusObject
Close() error
}
type LogindBackend struct {
conn DBusConn
connOnce bool
}
func NewLogindBackend() (*LogindBackend, error) {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("connect to system bus: %w", err)
}
obj := conn.Object("org.freedesktop.login1", "/org/freedesktop/login1/session/auto")
call := obj.Call("org.freedesktop.DBus.Peer.Ping", 0)
if call.Err != nil {
conn.Close()
return nil, fmt.Errorf("logind not available: %w", call.Err)
}
conn.Close()
return &LogindBackend{}, nil
}
func NewLogindBackendWithConn(conn DBusConn) *LogindBackend {
return &LogindBackend{
conn: conn,
}
}
func (b *LogindBackend) SetBrightness(subsystem, name string, brightness uint32) error {
if b.conn == nil {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return fmt.Errorf("connect to system bus: %w", err)
}
b.conn = conn
}
obj := b.conn.Object("org.freedesktop.login1", "/org/freedesktop/login1/session/auto")
call := obj.Call("org.freedesktop.login1.Session.SetBrightness", 0, subsystem, name, brightness)
if call.Err != nil {
return fmt.Errorf("dbus call failed: %w", call.Err)
}
log.Debugf("logind: set %s/%s to %d", subsystem, name, brightness)
return nil
}
func (b *LogindBackend) Close() {
if b.conn != nil {
b.conn.Close()
}
}

View File

@@ -0,0 +1,95 @@
package brightness
import (
"errors"
"testing"
mocks_brightness "github.com/AvengeMedia/danklinux/internal/mocks/brightness"
mock_dbus "github.com/AvengeMedia/danklinux/internal/mocks/github.com/godbus/dbus/v5"
"github.com/godbus/dbus/v5"
"github.com/stretchr/testify/mock"
)
func TestLogindBackend_SetBrightness_Success(t *testing.T) {
mockConn := mocks_brightness.NewMockDBusConn(t)
mockObj := mock_dbus.NewMockBusObject(t)
backend := NewLogindBackendWithConn(mockConn)
mockConn.EXPECT().
Object("org.freedesktop.login1", dbus.ObjectPath("/org/freedesktop/login1/session/auto")).
Return(mockObj).
Once()
mockObj.EXPECT().
Call("org.freedesktop.login1.Session.SetBrightness", dbus.Flags(0), "backlight", "nvidia_0", uint32(75)).
Return(&dbus.Call{Err: nil}).
Once()
err := backend.SetBrightness("backlight", "nvidia_0", 75)
if err != nil {
t.Errorf("SetBrightness() error = %v, want nil", err)
}
}
func TestLogindBackend_SetBrightness_DBusError(t *testing.T) {
mockConn := mocks_brightness.NewMockDBusConn(t)
mockObj := mock_dbus.NewMockBusObject(t)
backend := NewLogindBackendWithConn(mockConn)
mockConn.EXPECT().
Object("org.freedesktop.login1", dbus.ObjectPath("/org/freedesktop/login1/session/auto")).
Return(mockObj).
Once()
dbusErr := errors.New("permission denied")
mockObj.EXPECT().
Call("org.freedesktop.login1.Session.SetBrightness", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(&dbus.Call{Err: dbusErr}).
Once()
err := backend.SetBrightness("backlight", "test_device", 50)
if err == nil {
t.Error("SetBrightness() error = nil, want error")
}
}
func TestLogindBackend_SetBrightness_LEDDevice(t *testing.T) {
mockConn := mocks_brightness.NewMockDBusConn(t)
mockObj := mock_dbus.NewMockBusObject(t)
backend := NewLogindBackendWithConn(mockConn)
mockConn.EXPECT().
Object("org.freedesktop.login1", dbus.ObjectPath("/org/freedesktop/login1/session/auto")).
Return(mockObj).
Once()
mockObj.EXPECT().
Call("org.freedesktop.login1.Session.SetBrightness", dbus.Flags(0), "leds", "test_led", uint32(128)).
Return(&dbus.Call{Err: nil}).
Once()
err := backend.SetBrightness("leds", "test_led", 128)
if err != nil {
t.Errorf("SetBrightness() error = %v, want nil", err)
}
}
func TestLogindBackend_Close(t *testing.T) {
mockConn := mocks_brightness.NewMockDBusConn(t)
backend := NewLogindBackendWithConn(mockConn)
mockConn.EXPECT().
Close().
Return(nil).
Once()
backend.Close()
}
func TestLogindBackend_Close_NilConn(t *testing.T) {
backend := &LogindBackend{conn: nil}
backend.Close()
}

View File

@@ -0,0 +1,383 @@
package brightness
import (
"fmt"
"sort"
"strings"
"time"
"github.com/AvengeMedia/danklinux/internal/log"
)
func NewManager() (*Manager, error) {
return NewManagerWithOptions(false)
}
func NewManagerWithOptions(exponential bool) (*Manager, error) {
m := &Manager{
subscribers: make(map[string]chan State),
updateSubscribers: make(map[string]chan DeviceUpdate),
stopChan: make(chan struct{}),
exponential: exponential,
}
go m.initLogind()
go m.initSysfs()
go m.initDDC()
return m, nil
}
func (m *Manager) initLogind() {
log.Debug("Initializing logind backend...")
logind, err := NewLogindBackend()
if err != nil {
log.Infof("Logind backend not available: %v", err)
log.Info("Will use direct sysfs access for brightness control")
return
}
m.logindBackend = logind
m.logindReady = true
log.Info("Logind backend initialized - will use for brightness control")
}
func (m *Manager) initSysfs() {
log.Debug("Initializing sysfs backend...")
sysfs, err := NewSysfsBackend()
if err != nil {
log.Warnf("Failed to initialize sysfs backend: %v", err)
return
}
devices, err := sysfs.GetDevices()
if err != nil {
log.Warnf("Failed to get initial sysfs devices: %v", err)
m.sysfsBackend = sysfs
m.sysfsReady = true
m.updateState()
return
}
log.Infof("Sysfs backend initialized with %d devices", len(devices))
for _, d := range devices {
log.Debugf(" - %s: %s (%d%%)", d.ID, d.Name, d.CurrentPercent)
}
m.sysfsBackend = sysfs
m.sysfsReady = true
m.updateState()
}
func (m *Manager) initDDC() {
ddc, err := NewDDCBackend()
if err != nil {
log.Debugf("Failed to initialize DDC backend: %v", err)
return
}
m.ddcBackend = ddc
m.ddcReady = true
log.Info("DDC backend initialized")
m.updateState()
}
func (m *Manager) Rescan() {
log.Debug("Rescanning brightness devices...")
m.updateState()
}
func sortDevices(devices []Device) {
sort.Slice(devices, func(i, j int) bool {
classOrder := map[DeviceClass]int{
ClassBacklight: 0,
ClassDDC: 1,
ClassLED: 2,
}
orderI := classOrder[devices[i].Class]
orderJ := classOrder[devices[j].Class]
if orderI != orderJ {
return orderI < orderJ
}
return devices[i].Name < devices[j].Name
})
}
func stateChanged(old, new State) bool {
if len(old.Devices) != len(new.Devices) {
return true
}
oldMap := make(map[string]Device)
for _, d := range old.Devices {
oldMap[d.ID] = d
}
for _, newDev := range new.Devices {
oldDev, exists := oldMap[newDev.ID]
if !exists {
return true
}
if oldDev.Current != newDev.Current || oldDev.Max != newDev.Max {
return true
}
}
return false
}
func (m *Manager) updateState() {
allDevices := make([]Device, 0)
if m.sysfsReady && m.sysfsBackend != nil {
devices, err := m.sysfsBackend.GetDevices()
if err != nil {
log.Debugf("Failed to get sysfs devices: %v", err)
}
if err == nil {
allDevices = append(allDevices, devices...)
}
}
if m.ddcReady && m.ddcBackend != nil {
devices, err := m.ddcBackend.GetDevices()
if err != nil {
log.Debugf("Failed to get DDC devices: %v", err)
}
if err == nil {
allDevices = append(allDevices, devices...)
}
}
sortDevices(allDevices)
m.stateMutex.Lock()
oldState := m.state
newState := State{Devices: allDevices}
if !stateChanged(oldState, newState) {
m.stateMutex.Unlock()
return
}
m.state = newState
m.stateMutex.Unlock()
log.Debugf("State changed, notifying subscribers")
m.NotifySubscribers()
}
func (m *Manager) SetBrightness(deviceID string, percent int) error {
return m.SetBrightnessWithMode(deviceID, percent, m.exponential)
}
func (m *Manager) SetBrightnessWithMode(deviceID string, percent int, exponential bool) error {
return m.SetBrightnessWithExponent(deviceID, percent, exponential, 1.2)
}
func (m *Manager) SetBrightnessWithExponent(deviceID string, percent int, exponential bool, exponent float64) error {
if percent < 0 {
return fmt.Errorf("percent out of range: %d", percent)
}
log.Debugf("SetBrightness: %s to %d%%", deviceID, percent)
m.stateMutex.Lock()
currentState := m.state
var found bool
var deviceClass DeviceClass
var deviceIndex int
log.Debugf("Current state has %d devices", len(currentState.Devices))
for i, dev := range currentState.Devices {
if dev.ID == deviceID {
found = true
deviceClass = dev.Class
deviceIndex = i
break
}
}
if !found {
m.stateMutex.Unlock()
log.Debugf("Device not found in state: %s", deviceID)
return fmt.Errorf("device not found: %s", deviceID)
}
newDevices := make([]Device, len(currentState.Devices))
copy(newDevices, currentState.Devices)
newDevices[deviceIndex].CurrentPercent = percent
m.state = State{Devices: newDevices}
m.stateMutex.Unlock()
var err error
if deviceClass == ClassDDC {
log.Debugf("Calling DDC backend for %s", deviceID)
err = m.ddcBackend.SetBrightnessWithExponent(deviceID, percent, exponential, exponent, func() {
m.updateState()
m.debouncedBroadcast(deviceID)
})
} else if m.logindReady && m.logindBackend != nil {
log.Debugf("Calling logind backend for %s", deviceID)
err = m.setViaSysfsWithLogindWithExponent(deviceID, percent, exponential, exponent)
} else {
log.Debugf("Calling sysfs backend for %s", deviceID)
err = m.sysfsBackend.SetBrightnessWithExponent(deviceID, percent, exponential, exponent)
}
if err != nil {
m.updateState()
return fmt.Errorf("failed to set brightness: %w", err)
}
if deviceClass != ClassDDC {
log.Debugf("Queueing broadcast for %s", deviceID)
m.debouncedBroadcast(deviceID)
}
return nil
}
func (m *Manager) IncrementBrightness(deviceID string, step int) error {
return m.IncrementBrightnessWithMode(deviceID, step, m.exponential)
}
func (m *Manager) IncrementBrightnessWithMode(deviceID string, step int, exponential bool) error {
return m.IncrementBrightnessWithExponent(deviceID, step, exponential, 1.2)
}
func (m *Manager) IncrementBrightnessWithExponent(deviceID string, step int, exponential bool, exponent float64) error {
m.stateMutex.RLock()
currentState := m.state
m.stateMutex.RUnlock()
var currentPercent int
var found bool
for _, dev := range currentState.Devices {
if dev.ID == deviceID {
currentPercent = dev.CurrentPercent
found = true
break
}
}
if !found {
return fmt.Errorf("device not found: %s", deviceID)
}
newPercent := currentPercent + step
if newPercent > 100 {
newPercent = 100
}
if newPercent < 0 {
newPercent = 0
}
return m.SetBrightnessWithExponent(deviceID, newPercent, exponential, exponent)
}
func (m *Manager) DecrementBrightness(deviceID string, step int) error {
return m.IncrementBrightness(deviceID, -step)
}
func (m *Manager) setViaSysfsWithLogind(deviceID string, percent int, exponential bool) error {
return m.setViaSysfsWithLogindWithExponent(deviceID, percent, exponential, 1.2)
}
func (m *Manager) setViaSysfsWithLogindWithExponent(deviceID string, percent int, exponential bool, exponent float64) error {
parts := strings.SplitN(deviceID, ":", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid device id: %s", deviceID)
}
subsystem := parts[0]
name := parts[1]
dev, err := m.sysfsBackend.GetDevice(deviceID)
if err != nil {
return err
}
value := m.sysfsBackend.PercentToValueWithExponent(percent, dev, exponential, exponent)
if m.logindBackend == nil {
return m.sysfsBackend.SetBrightnessWithExponent(deviceID, percent, exponential, exponent)
}
err = m.logindBackend.SetBrightness(subsystem, name, uint32(value))
if err != nil {
log.Debugf("logind SetBrightness failed, falling back to direct sysfs: %v", err)
return m.sysfsBackend.SetBrightnessWithExponent(deviceID, percent, exponential, exponent)
}
log.Debugf("set %s to %d%% (%d/%d) via logind", deviceID, percent, value, dev.maxBrightness)
return nil
}
func (m *Manager) debouncedBroadcast(deviceID string) {
m.broadcastMutex.Lock()
defer m.broadcastMutex.Unlock()
m.broadcastPending = true
m.pendingDeviceID = deviceID
if m.broadcastTimer == nil {
m.broadcastTimer = time.AfterFunc(150*time.Millisecond, func() {
m.broadcastMutex.Lock()
pending := m.broadcastPending
deviceID := m.pendingDeviceID
m.broadcastPending = false
m.pendingDeviceID = ""
m.broadcastMutex.Unlock()
if !pending || deviceID == "" {
return
}
m.broadcastDeviceUpdate(deviceID)
})
} else {
m.broadcastTimer.Reset(150 * time.Millisecond)
}
}
func (m *Manager) broadcastDeviceUpdate(deviceID string) {
m.stateMutex.RLock()
var targetDevice *Device
for _, dev := range m.state.Devices {
if dev.ID == deviceID {
devCopy := dev
targetDevice = &devCopy
break
}
}
m.stateMutex.RUnlock()
if targetDevice == nil {
log.Debugf("Device not found for broadcast: %s", deviceID)
return
}
update := DeviceUpdate{Device: *targetDevice}
m.subMutex.RLock()
defer m.subMutex.RUnlock()
if len(m.updateSubscribers) == 0 {
log.Debugf("No update subscribers for device: %s", deviceID)
return
}
log.Debugf("Broadcasting device update: %s at %d%%", deviceID, targetDevice.CurrentPercent)
for _, ch := range m.updateSubscribers {
select {
case ch <- update:
default:
}
}
}

View File

@@ -0,0 +1,11 @@
package brightness
import (
"testing"
)
// Manager tests can be added here as needed
func TestManager_Placeholder(t *testing.T) {
// Placeholder test to keep the test file valid
t.Skip("No tests implemented yet")
}

View File

@@ -0,0 +1,272 @@
package brightness
import (
"fmt"
"math"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/AvengeMedia/danklinux/internal/log"
)
func NewSysfsBackend() (*SysfsBackend, error) {
b := &SysfsBackend{
basePath: "/sys/class",
classes: []string{"backlight", "leds"},
deviceCache: make(map[string]*sysfsDevice),
}
if err := b.scanDevices(); err != nil {
return nil, err
}
return b, nil
}
func (b *SysfsBackend) scanDevices() error {
b.deviceCacheMutex.Lock()
defer b.deviceCacheMutex.Unlock()
for _, class := range b.classes {
classPath := filepath.Join(b.basePath, class)
entries, err := os.ReadDir(classPath)
if err != nil {
if os.IsNotExist(err) {
continue
}
return fmt.Errorf("read %s: %w", classPath, err)
}
for _, entry := range entries {
devicePath := filepath.Join(classPath, entry.Name())
stat, err := os.Stat(devicePath)
if err != nil || !stat.IsDir() {
continue
}
maxPath := filepath.Join(devicePath, "max_brightness")
maxData, err := os.ReadFile(maxPath)
if err != nil {
log.Debugf("skip %s/%s: no max_brightness", class, entry.Name())
continue
}
maxBrightness, err := strconv.Atoi(strings.TrimSpace(string(maxData)))
if err != nil || maxBrightness <= 0 {
log.Debugf("skip %s/%s: invalid max_brightness", class, entry.Name())
continue
}
deviceClass := ClassBacklight
minValue := 1
if class == "leds" {
deviceClass = ClassLED
minValue = 0
}
deviceID := fmt.Sprintf("%s:%s", class, entry.Name())
b.deviceCache[deviceID] = &sysfsDevice{
class: deviceClass,
id: deviceID,
name: entry.Name(),
maxBrightness: maxBrightness,
minValue: minValue,
}
log.Debugf("found %s device: %s (max=%d)", class, entry.Name(), maxBrightness)
}
}
return nil
}
func shouldSuppressDevice(name string) bool {
if strings.HasSuffix(name, "::lan") {
return true
}
keyboardLEDs := []string{
"::scrolllock",
"::capslock",
"::numlock",
"::kana",
"::compose",
}
for _, suffix := range keyboardLEDs {
if strings.HasSuffix(name, suffix) {
return true
}
}
return false
}
func (b *SysfsBackend) GetDevices() ([]Device, error) {
b.deviceCacheMutex.RLock()
defer b.deviceCacheMutex.RUnlock()
devices := make([]Device, 0, len(b.deviceCache))
for _, dev := range b.deviceCache {
if shouldSuppressDevice(dev.name) {
continue
}
parts := strings.SplitN(dev.id, ":", 2)
if len(parts) != 2 {
continue
}
class := parts[0]
name := parts[1]
devicePath := filepath.Join(b.basePath, class, name)
brightnessPath := filepath.Join(devicePath, "brightness")
brightnessData, err := os.ReadFile(brightnessPath)
if err != nil {
log.Debugf("failed to read brightness for %s: %v", dev.id, err)
continue
}
current, err := strconv.Atoi(strings.TrimSpace(string(brightnessData)))
if err != nil {
log.Debugf("failed to parse brightness for %s: %v", dev.id, err)
continue
}
percent := b.ValueToPercent(current, dev, false)
devices = append(devices, Device{
Class: dev.class,
ID: dev.id,
Name: dev.name,
Current: current,
Max: dev.maxBrightness,
CurrentPercent: percent,
Backend: "sysfs",
})
}
return devices, nil
}
func (b *SysfsBackend) GetDevice(id string) (*sysfsDevice, error) {
b.deviceCacheMutex.RLock()
defer b.deviceCacheMutex.RUnlock()
dev, ok := b.deviceCache[id]
if !ok {
return nil, fmt.Errorf("device not found: %s", id)
}
return dev, nil
}
func (b *SysfsBackend) SetBrightness(id string, percent int, exponential bool) error {
return b.SetBrightnessWithExponent(id, percent, exponential, 1.2)
}
func (b *SysfsBackend) SetBrightnessWithExponent(id string, percent int, exponential bool, exponent float64) error {
dev, err := b.GetDevice(id)
if err != nil {
return err
}
if percent < 0 {
return fmt.Errorf("percent out of range: %d", percent)
}
value := b.PercentToValueWithExponent(percent, dev, exponential, exponent)
parts := strings.SplitN(id, ":", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid device id: %s", id)
}
class := parts[0]
name := parts[1]
devicePath := filepath.Join(b.basePath, class, name)
brightnessPath := filepath.Join(devicePath, "brightness")
data := []byte(fmt.Sprintf("%d", value))
if err := os.WriteFile(brightnessPath, data, 0644); err != nil {
return fmt.Errorf("write brightness: %w", err)
}
log.Debugf("set %s to %d%% (%d/%d) via direct sysfs", id, percent, value, dev.maxBrightness)
return nil
}
func (b *SysfsBackend) PercentToValue(percent int, dev *sysfsDevice, exponential bool) int {
return b.PercentToValueWithExponent(percent, dev, exponential, 1.2)
}
func (b *SysfsBackend) PercentToValueWithExponent(percent int, dev *sysfsDevice, exponential bool, exponent float64) int {
if percent == 0 {
return dev.minValue
}
usableRange := dev.maxBrightness - dev.minValue
var value int
if exponential {
normalizedPercent := float64(percent-1) / 99.0
hardwarePercent := math.Pow(normalizedPercent, exponent)
value = dev.minValue + int(math.Round(hardwarePercent*float64(usableRange)))
} else {
value = dev.minValue + ((percent - 1) * usableRange / 99)
}
if value < dev.minValue {
value = dev.minValue
}
if value > dev.maxBrightness {
value = dev.maxBrightness
}
return value
}
func (b *SysfsBackend) ValueToPercent(value int, dev *sysfsDevice, exponential bool) int {
return b.ValueToPercentWithExponent(value, dev, exponential, 1.2)
}
func (b *SysfsBackend) ValueToPercentWithExponent(value int, dev *sysfsDevice, exponential bool, exponent float64) int {
if value <= dev.minValue {
if dev.minValue == 0 && value == 0 {
return 0
}
return 1
}
usableRange := dev.maxBrightness - dev.minValue
if usableRange == 0 {
return 100
}
var percent int
if exponential {
hardwarePercent := float64(value-dev.minValue) / float64(usableRange)
normalizedPercent := math.Pow(hardwarePercent, 1.0/exponent)
percent = 1 + int(math.Round(normalizedPercent*99.0))
} else {
percent = 1 + int(math.Round(float64(value-dev.minValue)*99.0/float64(usableRange)))
}
if percent > 100 {
percent = 100
}
if percent < 1 {
percent = 1
}
return percent
}

View File

@@ -0,0 +1,290 @@
package brightness
import (
"os"
"path/filepath"
"testing"
mocks_brightness "github.com/AvengeMedia/danklinux/internal/mocks/brightness"
mock_dbus "github.com/AvengeMedia/danklinux/internal/mocks/github.com/godbus/dbus/v5"
"github.com/godbus/dbus/v5"
"github.com/stretchr/testify/mock"
)
func TestManager_SetBrightness_LogindSuccess(t *testing.T) {
tmpDir := t.TempDir()
backlightDir := filepath.Join(tmpDir, "backlight", "test_backlight")
if err := os.MkdirAll(backlightDir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "max_brightness"), []byte("100\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "brightness"), []byte("50\n"), 0644); err != nil {
t.Fatal(err)
}
mockConn := mocks_brightness.NewMockDBusConn(t)
mockObj := mock_dbus.NewMockBusObject(t)
mockLogind := NewLogindBackendWithConn(mockConn)
sysfs := &SysfsBackend{
basePath: tmpDir,
classes: []string{"backlight"},
deviceCache: make(map[string]*sysfsDevice),
}
if err := sysfs.scanDevices(); err != nil {
t.Fatal(err)
}
m := &Manager{
logindBackend: mockLogind,
sysfsBackend: sysfs,
logindReady: true,
sysfsReady: true,
subscribers: make(map[string]chan State),
updateSubscribers: make(map[string]chan DeviceUpdate),
stopChan: make(chan struct{}),
}
m.state = State{
Devices: []Device{
{
Class: ClassBacklight,
ID: "backlight:test_backlight",
Name: "test_backlight",
Current: 50,
Max: 100,
CurrentPercent: 50,
Backend: "sysfs",
},
},
}
mockConn.EXPECT().
Object("org.freedesktop.login1", dbus.ObjectPath("/org/freedesktop/login1/session/auto")).
Return(mockObj).
Once()
mockObj.EXPECT().
Call("org.freedesktop.login1.Session.SetBrightness", mock.Anything, "backlight", "test_backlight", uint32(75)).
Return(&dbus.Call{Err: nil}).
Once()
err := m.SetBrightness("backlight:test_backlight", 75)
if err != nil {
t.Errorf("SetBrightness() with logind error = %v, want nil", err)
}
data, _ := os.ReadFile(filepath.Join(backlightDir, "brightness"))
if string(data) == "75\n" {
t.Error("Direct sysfs write occurred when logind should have been used")
}
}
func TestManager_SetBrightness_LogindFailsFallbackToSysfs(t *testing.T) {
tmpDir := t.TempDir()
backlightDir := filepath.Join(tmpDir, "backlight", "test_backlight")
if err := os.MkdirAll(backlightDir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "max_brightness"), []byte("100\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "brightness"), []byte("50\n"), 0644); err != nil {
t.Fatal(err)
}
mockConn := mocks_brightness.NewMockDBusConn(t)
mockObj := mock_dbus.NewMockBusObject(t)
mockLogind := NewLogindBackendWithConn(mockConn)
sysfs := &SysfsBackend{
basePath: tmpDir,
classes: []string{"backlight"},
deviceCache: make(map[string]*sysfsDevice),
}
if err := sysfs.scanDevices(); err != nil {
t.Fatal(err)
}
m := &Manager{
logindBackend: mockLogind,
sysfsBackend: sysfs,
logindReady: true,
sysfsReady: true,
subscribers: make(map[string]chan State),
updateSubscribers: make(map[string]chan DeviceUpdate),
stopChan: make(chan struct{}),
}
m.state = State{
Devices: []Device{
{
Class: ClassBacklight,
ID: "backlight:test_backlight",
Name: "test_backlight",
Current: 50,
Max: 100,
CurrentPercent: 50,
Backend: "sysfs",
},
},
}
mockConn.EXPECT().
Object("org.freedesktop.login1", dbus.ObjectPath("/org/freedesktop/login1/session/auto")).
Return(mockObj).
Once()
mockObj.EXPECT().
Call("org.freedesktop.login1.Session.SetBrightness", mock.Anything, "backlight", "test_backlight", mock.Anything).
Return(&dbus.Call{Err: dbus.ErrMsgNoObject}).
Once()
err := m.SetBrightness("backlight:test_backlight", 75)
if err != nil {
t.Errorf("SetBrightness() with fallback error = %v, want nil", err)
}
data, _ := os.ReadFile(filepath.Join(backlightDir, "brightness"))
brightness := string(data)
if brightness != "75" {
t.Errorf("Fallback sysfs write did not occur, got brightness = %q, want %q", brightness, "75")
}
}
func TestManager_SetBrightness_NoLogind(t *testing.T) {
tmpDir := t.TempDir()
backlightDir := filepath.Join(tmpDir, "backlight", "test_backlight")
if err := os.MkdirAll(backlightDir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "max_brightness"), []byte("100\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "brightness"), []byte("50\n"), 0644); err != nil {
t.Fatal(err)
}
sysfs := &SysfsBackend{
basePath: tmpDir,
classes: []string{"backlight"},
deviceCache: make(map[string]*sysfsDevice),
}
if err := sysfs.scanDevices(); err != nil {
t.Fatal(err)
}
m := &Manager{
logindBackend: nil,
sysfsBackend: sysfs,
logindReady: false,
sysfsReady: true,
subscribers: make(map[string]chan State),
updateSubscribers: make(map[string]chan DeviceUpdate),
stopChan: make(chan struct{}),
}
m.state = State{
Devices: []Device{
{
Class: ClassBacklight,
ID: "backlight:test_backlight",
Name: "test_backlight",
Current: 50,
Max: 100,
CurrentPercent: 50,
Backend: "sysfs",
},
},
}
err := m.SetBrightness("backlight:test_backlight", 75)
if err != nil {
t.Errorf("SetBrightness() without logind error = %v, want nil", err)
}
data, _ := os.ReadFile(filepath.Join(backlightDir, "brightness"))
brightness := string(data)
if brightness != "75" {
t.Errorf("Direct sysfs write = %q, want %q", brightness, "75")
}
}
func TestManager_SetBrightness_LEDWithLogind(t *testing.T) {
tmpDir := t.TempDir()
ledsDir := filepath.Join(tmpDir, "leds", "test_led")
if err := os.MkdirAll(ledsDir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(ledsDir, "max_brightness"), []byte("255\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(ledsDir, "brightness"), []byte("128\n"), 0644); err != nil {
t.Fatal(err)
}
mockConn := mocks_brightness.NewMockDBusConn(t)
mockObj := mock_dbus.NewMockBusObject(t)
mockLogind := NewLogindBackendWithConn(mockConn)
sysfs := &SysfsBackend{
basePath: tmpDir,
classes: []string{"leds"},
deviceCache: make(map[string]*sysfsDevice),
}
if err := sysfs.scanDevices(); err != nil {
t.Fatal(err)
}
m := &Manager{
logindBackend: mockLogind,
sysfsBackend: sysfs,
logindReady: true,
sysfsReady: true,
subscribers: make(map[string]chan State),
updateSubscribers: make(map[string]chan DeviceUpdate),
stopChan: make(chan struct{}),
}
m.state = State{
Devices: []Device{
{
Class: ClassLED,
ID: "leds:test_led",
Name: "test_led",
Current: 128,
Max: 255,
CurrentPercent: 50,
Backend: "sysfs",
},
},
}
mockConn.EXPECT().
Object("org.freedesktop.login1", dbus.ObjectPath("/org/freedesktop/login1/session/auto")).
Return(mockObj).
Once()
mockObj.EXPECT().
Call("org.freedesktop.login1.Session.SetBrightness", mock.Anything, "leds", "test_led", uint32(0)).
Return(&dbus.Call{Err: nil}).
Once()
err := m.SetBrightness("leds:test_led", 0)
if err != nil {
t.Errorf("SetBrightness() LED with logind error = %v, want nil", err)
}
}

View File

@@ -0,0 +1,185 @@
package brightness
import (
"os"
"path/filepath"
"testing"
)
func TestSysfsBackend_PercentConversions(t *testing.T) {
tests := []struct {
name string
device *sysfsDevice
percent int
wantValue int
tolerance int
}{
{
name: "backlight 0% should be minValue=1",
device: &sysfsDevice{maxBrightness: 100, minValue: 1, class: ClassBacklight},
percent: 0,
wantValue: 1,
tolerance: 0,
},
{
name: "backlight 1% should be minValue=1",
device: &sysfsDevice{maxBrightness: 100, minValue: 1, class: ClassBacklight},
percent: 1,
wantValue: 1,
tolerance: 0,
},
{
name: "backlight 50% should be ~50",
device: &sysfsDevice{maxBrightness: 100, minValue: 1, class: ClassBacklight},
percent: 50,
wantValue: 50,
tolerance: 1,
},
{
name: "backlight 100% should be max",
device: &sysfsDevice{maxBrightness: 100, minValue: 1, class: ClassBacklight},
percent: 100,
wantValue: 100,
tolerance: 0,
},
{
name: "led 0% should be 0",
device: &sysfsDevice{maxBrightness: 255, minValue: 0, class: ClassLED},
percent: 0,
wantValue: 0,
tolerance: 0,
},
{
name: "led 1% should be ~2-3",
device: &sysfsDevice{maxBrightness: 255, minValue: 0, class: ClassLED},
percent: 1,
wantValue: 2,
tolerance: 3,
},
{
name: "led 50% should be ~127",
device: &sysfsDevice{maxBrightness: 255, minValue: 0, class: ClassLED},
percent: 50,
wantValue: 127,
tolerance: 2,
},
{
name: "led 100% should be max",
device: &sysfsDevice{maxBrightness: 255, minValue: 0, class: ClassLED},
percent: 100,
wantValue: 255,
tolerance: 0,
},
}
b := &SysfsBackend{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := b.PercentToValue(tt.percent, tt.device, false)
diff := got - tt.wantValue
if diff < 0 {
diff = -diff
}
if diff > tt.tolerance {
t.Errorf("percentToValue() = %v, want %v (±%d)", got, tt.wantValue, tt.tolerance)
}
gotPercent := b.ValueToPercent(got, tt.device, false)
if tt.percent > 1 && gotPercent == 0 {
t.Errorf("valueToPercent() returned 0 for non-zero input (percent=%d, got value=%d)", tt.percent, got)
}
})
}
}
func TestSysfsBackend_ValueToPercent(t *testing.T) {
tests := []struct {
name string
device *sysfsDevice
value int
wantPercent int
}{
{
name: "backlight min value",
device: &sysfsDevice{maxBrightness: 100, minValue: 1, class: ClassBacklight},
value: 1,
wantPercent: 1,
},
{
name: "backlight max value",
device: &sysfsDevice{maxBrightness: 100, minValue: 1, class: ClassBacklight},
value: 100,
wantPercent: 100,
},
{
name: "led zero",
device: &sysfsDevice{maxBrightness: 255, minValue: 0, class: ClassLED},
value: 0,
wantPercent: 0,
},
}
b := &SysfsBackend{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := b.ValueToPercent(tt.value, tt.device, false)
if got != tt.wantPercent {
t.Errorf("valueToPercent() = %v, want %v", got, tt.wantPercent)
}
})
}
}
func TestSysfsBackend_ScanDevices(t *testing.T) {
tmpDir := t.TempDir()
backlightDir := filepath.Join(tmpDir, "backlight", "test_backlight")
if err := os.MkdirAll(backlightDir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "max_brightness"), []byte("100\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(backlightDir, "brightness"), []byte("50\n"), 0644); err != nil {
t.Fatal(err)
}
ledsDir := filepath.Join(tmpDir, "leds", "test_led")
if err := os.MkdirAll(ledsDir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(ledsDir, "max_brightness"), []byte("255\n"), 0644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(ledsDir, "brightness"), []byte("128\n"), 0644); err != nil {
t.Fatal(err)
}
b := &SysfsBackend{
basePath: tmpDir,
classes: []string{"backlight", "leds"},
deviceCache: make(map[string]*sysfsDevice),
}
if err := b.scanDevices(); err != nil {
t.Fatalf("scanDevices() error = %v", err)
}
if len(b.deviceCache) != 2 {
t.Errorf("expected 2 devices, got %d", len(b.deviceCache))
}
backlightID := "backlight:test_backlight"
if _, ok := b.deviceCache[backlightID]; !ok {
t.Errorf("backlight device not found")
}
ledID := "leds:test_led"
if _, ok := b.deviceCache[ledID]; !ok {
t.Errorf("LED device not found")
}
}

View File

@@ -0,0 +1,199 @@
package brightness
import (
"sync"
"time"
)
type DeviceClass string
const (
ClassBacklight DeviceClass = "backlight"
ClassLED DeviceClass = "leds"
ClassDDC DeviceClass = "ddc"
)
type Device struct {
Class DeviceClass `json:"class"`
ID string `json:"id"`
Name string `json:"name"`
Current int `json:"current"`
Max int `json:"max"`
CurrentPercent int `json:"currentPercent"`
Backend string `json:"backend"`
}
type State struct {
Devices []Device `json:"devices"`
}
type DeviceUpdate struct {
Device Device `json:"device"`
}
type Request struct {
ID interface{} `json:"id"`
Method string `json:"method"`
Params map[string]interface{} `json:"params"`
}
type Manager struct {
logindBackend *LogindBackend
sysfsBackend *SysfsBackend
ddcBackend *DDCBackend
logindReady bool
sysfsReady bool
ddcReady bool
exponential bool
stateMutex sync.RWMutex
state State
subscribers map[string]chan State
updateSubscribers map[string]chan DeviceUpdate
subMutex sync.RWMutex
broadcastMutex sync.Mutex
broadcastTimer *time.Timer
broadcastPending bool
pendingDeviceID string
stopChan chan struct{}
}
type SysfsBackend struct {
basePath string
classes []string
deviceCache map[string]*sysfsDevice
deviceCacheMutex sync.RWMutex
}
type sysfsDevice struct {
class DeviceClass
id string
name string
maxBrightness int
minValue int
}
type DDCBackend struct {
devices map[string]*ddcDevice
devicesMutex sync.RWMutex
scanMutex sync.Mutex
lastScan time.Time
scanInterval time.Duration
debounceMutex sync.Mutex
debounceTimers map[string]*time.Timer
debouncePending map[string]ddcPendingSet
}
type ddcPendingSet struct {
percent int
callback func()
}
type ddcDevice struct {
bus int
addr int
id string
name string
max int
lastBrightness int
}
type ddcCapability struct {
vcp byte
max int
current int
}
type SetBrightnessParams struct {
Device string `json:"device"`
Percent int `json:"percent"`
Exponential bool `json:"exponential,omitempty"`
Exponent float64 `json:"exponent,omitempty"`
}
func (m *Manager) Subscribe(id string) chan State {
ch := make(chan State, 16)
m.subMutex.Lock()
m.subscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) SubscribeUpdates(id string) chan DeviceUpdate {
ch := make(chan DeviceUpdate, 16)
m.subMutex.Lock()
m.updateSubscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) UnsubscribeUpdates(id string) {
m.subMutex.Lock()
if ch, ok := m.updateSubscribers[id]; ok {
close(ch)
delete(m.updateSubscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) NotifySubscribers() {
m.stateMutex.RLock()
state := m.state
m.stateMutex.RUnlock()
m.subMutex.RLock()
defer m.subMutex.RUnlock()
for _, ch := range m.subscribers {
select {
case ch <- state:
default:
}
}
}
func (m *Manager) GetState() State {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
return m.state
}
func (m *Manager) Close() {
close(m.stopChan)
m.subMutex.Lock()
for _, ch := range m.subscribers {
close(ch)
}
m.subscribers = make(map[string]chan State)
for _, ch := range m.updateSubscribers {
close(ch)
}
m.updateSubscribers = make(map[string]chan DeviceUpdate)
m.subMutex.Unlock()
if m.logindBackend != nil {
m.logindBackend.Close()
}
if m.ddcBackend != nil {
m.ddcBackend.Close()
}
}

View File

@@ -0,0 +1,107 @@
package cups
import (
"strings"
"time"
"github.com/AvengeMedia/danklinux/pkg/ipp"
)
func (m *Manager) GetPrinters() ([]Printer, error) {
attributes := []string{
ipp.AttributePrinterName,
ipp.AttributePrinterUriSupported,
ipp.AttributePrinterState,
ipp.AttributePrinterStateReasons,
ipp.AttributePrinterLocation,
ipp.AttributePrinterInfo,
ipp.AttributePrinterMakeAndModel,
ipp.AttributePrinterIsAcceptingJobs,
}
printerAttrs, err := m.client.GetPrinters(attributes)
if err != nil {
return nil, err
}
printers := make([]Printer, 0, len(printerAttrs))
for _, attrs := range printerAttrs {
printer := Printer{
Name: getStringAttr(attrs, ipp.AttributePrinterName),
URI: getStringAttr(attrs, ipp.AttributePrinterUriSupported),
State: parsePrinterState(attrs),
StateReason: getStringAttr(attrs, ipp.AttributePrinterStateReasons),
Location: getStringAttr(attrs, ipp.AttributePrinterLocation),
Info: getStringAttr(attrs, ipp.AttributePrinterInfo),
MakeModel: getStringAttr(attrs, ipp.AttributePrinterMakeAndModel),
Accepting: getBoolAttr(attrs, ipp.AttributePrinterIsAcceptingJobs),
}
if printer.Name != "" {
printers = append(printers, printer)
}
}
return printers, nil
}
func (m *Manager) GetJobs(printerName string, whichJobs string) ([]Job, error) {
attributes := []string{
ipp.AttributeJobID,
ipp.AttributeJobName,
ipp.AttributeJobState,
ipp.AttributeJobPrinterURI,
ipp.AttributeJobOriginatingUserName,
ipp.AttributeJobKilobyteOctets,
"time-at-creation",
}
jobAttrs, err := m.client.GetJobs(printerName, "", whichJobs, false, 0, 0, attributes)
if err != nil {
return nil, err
}
jobs := make([]Job, 0, len(jobAttrs))
for _, attrs := range jobAttrs {
job := Job{
ID: getIntAttr(attrs, ipp.AttributeJobID),
Name: getStringAttr(attrs, ipp.AttributeJobName),
State: parseJobState(attrs),
User: getStringAttr(attrs, ipp.AttributeJobOriginatingUserName),
Size: getIntAttr(attrs, ipp.AttributeJobKilobyteOctets) * 1024,
}
if uri := getStringAttr(attrs, ipp.AttributeJobPrinterURI); uri != "" {
parts := strings.Split(uri, "/")
if len(parts) > 0 {
job.Printer = parts[len(parts)-1]
}
}
if ts := getIntAttr(attrs, "time-at-creation"); ts > 0 {
job.TimeCreated = time.Unix(int64(ts), 0)
}
if job.ID != 0 {
jobs = append(jobs, job)
}
}
return jobs, nil
}
func (m *Manager) CancelJob(jobID int) error {
return m.client.CancelJob(jobID, false)
}
func (m *Manager) PausePrinter(printerName string) error {
return m.client.PausePrinter(printerName)
}
func (m *Manager) ResumePrinter(printerName string) error {
return m.client.ResumePrinter(printerName)
}
func (m *Manager) PurgeJobs(printerName string) error {
return m.client.CancelAllJob(printerName, true)
}

View File

@@ -0,0 +1,285 @@
package cups
import (
"errors"
"testing"
"time"
mocks_cups "github.com/AvengeMedia/danklinux/internal/mocks/cups"
"github.com/AvengeMedia/danklinux/pkg/ipp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestManager_GetPrinters(t *testing.T) {
tests := []struct {
name string
mockRet map[string]ipp.Attributes
mockErr error
want int
wantErr bool
}{
{
name: "success",
mockRet: map[string]ipp.Attributes{
"printer1": {
ipp.AttributePrinterName: []ipp.Attribute{{Value: "printer1"}},
ipp.AttributePrinterUriSupported: []ipp.Attribute{{Value: "ipp://localhost/printers/printer1"}},
ipp.AttributePrinterState: []ipp.Attribute{{Value: 3}},
ipp.AttributePrinterStateReasons: []ipp.Attribute{{Value: "none"}},
ipp.AttributePrinterLocation: []ipp.Attribute{{Value: "Office"}},
ipp.AttributePrinterInfo: []ipp.Attribute{{Value: "Test Printer"}},
ipp.AttributePrinterMakeAndModel: []ipp.Attribute{{Value: "Generic"}},
ipp.AttributePrinterIsAcceptingJobs: []ipp.Attribute{{Value: true}},
},
},
mockErr: nil,
want: 1,
wantErr: false,
},
{
name: "error",
mockRet: nil,
mockErr: errors.New("test error"),
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().GetPrinters(mock.Anything).Return(tt.mockRet, tt.mockErr)
m := &Manager{
client: mockClient,
}
got, err := m.GetPrinters()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, len(got))
if len(got) > 0 {
assert.Equal(t, "printer1", got[0].Name)
assert.Equal(t, "idle", got[0].State)
assert.Equal(t, "Office", got[0].Location)
assert.True(t, got[0].Accepting)
}
}
})
}
}
func TestManager_GetJobs(t *testing.T) {
tests := []struct {
name string
mockRet map[int]ipp.Attributes
mockErr error
want int
wantErr bool
}{
{
name: "success",
mockRet: map[int]ipp.Attributes{
1: {
ipp.AttributeJobID: []ipp.Attribute{{Value: 1}},
ipp.AttributeJobName: []ipp.Attribute{{Value: "test-job"}},
ipp.AttributeJobState: []ipp.Attribute{{Value: 5}},
ipp.AttributeJobPrinterURI: []ipp.Attribute{{Value: "ipp://localhost/printers/printer1"}},
ipp.AttributeJobOriginatingUserName: []ipp.Attribute{{Value: "testuser"}},
ipp.AttributeJobKilobyteOctets: []ipp.Attribute{{Value: 10}},
"time-at-creation": []ipp.Attribute{{Value: 1609459200}},
},
},
mockErr: nil,
want: 1,
wantErr: false,
},
{
name: "error",
mockRet: nil,
mockErr: errors.New("test error"),
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().GetJobs("printer1", "", "not-completed", false, 0, 0, mock.Anything).
Return(tt.mockRet, tt.mockErr)
m := &Manager{
client: mockClient,
}
got, err := m.GetJobs("printer1", "not-completed")
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, len(got))
if len(got) > 0 {
assert.Equal(t, 1, got[0].ID)
assert.Equal(t, "test-job", got[0].Name)
assert.Equal(t, "processing", got[0].State)
assert.Equal(t, "testuser", got[0].User)
assert.Equal(t, "printer1", got[0].Printer)
assert.Equal(t, 10240, got[0].Size)
assert.Equal(t, time.Unix(1609459200, 0), got[0].TimeCreated)
}
}
})
}
}
func TestManager_CancelJob(t *testing.T) {
tests := []struct {
name string
mockErr error
wantErr bool
}{
{
name: "success",
mockErr: nil,
wantErr: false,
},
{
name: "error",
mockErr: errors.New("test error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().CancelJob(1, false).Return(tt.mockErr)
m := &Manager{
client: mockClient,
}
err := m.CancelJob(1)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestManager_PausePrinter(t *testing.T) {
tests := []struct {
name string
mockErr error
wantErr bool
}{
{
name: "success",
mockErr: nil,
wantErr: false,
},
{
name: "error",
mockErr: errors.New("test error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().PausePrinter("printer1").Return(tt.mockErr)
m := &Manager{
client: mockClient,
}
err := m.PausePrinter("printer1")
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestManager_ResumePrinter(t *testing.T) {
tests := []struct {
name string
mockErr error
wantErr bool
}{
{
name: "success",
mockErr: nil,
wantErr: false,
},
{
name: "error",
mockErr: errors.New("test error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().ResumePrinter("printer1").Return(tt.mockErr)
m := &Manager{
client: mockClient,
}
err := m.ResumePrinter("printer1")
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestManager_PurgeJobs(t *testing.T) {
tests := []struct {
name string
mockErr error
wantErr bool
}{
{
name: "success",
mockErr: nil,
wantErr: false,
},
{
name: "error",
mockErr: errors.New("test error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().CancelAllJob("printer1", true).Return(tt.mockErr)
m := &Manager{
client: mockClient,
}
err := m.PurgeJobs("printer1")
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,160 @@
package cups
import (
"encoding/json"
"fmt"
"net"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
type CUPSEvent struct {
Type string `json:"type"`
Data CUPSState `json:"data"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method {
case "cups.subscribe":
handleSubscribe(conn, req, manager)
case "cups.getPrinters":
handleGetPrinters(conn, req, manager)
case "cups.getJobs":
handleGetJobs(conn, req, manager)
case "cups.pausePrinter":
handlePausePrinter(conn, req, manager)
case "cups.resumePrinter":
handleResumePrinter(conn, req, manager)
case "cups.cancelJob":
handleCancelJob(conn, req, manager)
case "cups.purgeJobs":
handlePurgeJobs(conn, req, manager)
default:
models.RespondError(conn, req.ID, fmt.Sprintf("unknown method: %s", req.Method))
}
}
func handleGetPrinters(conn net.Conn, req Request, manager *Manager) {
printers, err := manager.GetPrinters()
if err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, printers)
}
func handleGetJobs(conn net.Conn, req Request, manager *Manager) {
printerName, ok := req.Params["printerName"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return
}
jobs, err := manager.GetJobs(printerName, "not-completed")
if err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, jobs)
}
func handlePausePrinter(conn net.Conn, req Request, manager *Manager) {
printerName, ok := req.Params["printerName"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return
}
if err := manager.PausePrinter(printerName); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "paused"})
}
func handleResumePrinter(conn net.Conn, req Request, manager *Manager) {
printerName, ok := req.Params["printerName"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return
}
if err := manager.ResumePrinter(printerName); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "resumed"})
}
func handleCancelJob(conn net.Conn, req Request, manager *Manager) {
jobIDFloat, ok := req.Params["jobID"].(float64)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'jobid' parameter")
return
}
jobID := int(jobIDFloat)
if err := manager.CancelJob(jobID); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "job canceled"})
}
func handlePurgeJobs(conn net.Conn, req Request, manager *Manager) {
printerName, ok := req.Params["printerName"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'printerName' parameter")
return
}
if err := manager.PurgeJobs(printerName); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "jobs canceled"})
}
func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID)
initialState := manager.GetState()
event := CUPSEvent{
Type: "state_changed",
Data: initialState,
}
if err := json.NewEncoder(conn).Encode(models.Response[CUPSEvent]{
ID: req.ID,
Result: &event,
}); err != nil {
return
}
for state := range stateChan {
event := CUPSEvent{
Type: "state_changed",
Data: state,
}
if err := json.NewEncoder(conn).Encode(models.Response[CUPSEvent]{
Result: &event,
}); err != nil {
return
}
}
}

View File

@@ -0,0 +1,279 @@
package cups
import (
"bytes"
"encoding/json"
"errors"
"net"
"testing"
"time"
mocks_cups "github.com/AvengeMedia/danklinux/internal/mocks/cups"
"github.com/AvengeMedia/danklinux/internal/server/models"
"github.com/AvengeMedia/danklinux/pkg/ipp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type mockConn struct {
*bytes.Buffer
}
func (m *mockConn) Close() error { return nil }
func (m *mockConn) LocalAddr() net.Addr { return nil }
func (m *mockConn) RemoteAddr() net.Addr { return nil }
func (m *mockConn) SetDeadline(t time.Time) error { return nil }
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestHandleGetPrinters(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().GetPrinters(mock.Anything).Return(map[string]ipp.Attributes{
"printer1": {
ipp.AttributePrinterName: []ipp.Attribute{{Value: "printer1"}},
ipp.AttributePrinterState: []ipp.Attribute{{Value: 3}},
ipp.AttributePrinterUriSupported: []ipp.Attribute{{Value: "ipp://localhost/printers/printer1"}},
},
}, nil)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.getPrinters",
}
handleGetPrinters(conn, req, m)
var resp models.Response[[]Printer]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.NotNil(t, resp.Result)
assert.Equal(t, 1, len(*resp.Result))
}
func TestHandleGetPrinters_Error(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().GetPrinters(mock.Anything).Return(nil, errors.New("test error"))
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.getPrinters",
}
handleGetPrinters(conn, req, m)
var resp models.Response[interface{}]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.Nil(t, resp.Result)
assert.NotNil(t, resp.Error)
}
func TestHandleGetJobs(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().GetJobs("printer1", "", "not-completed", false, 0, 0, mock.Anything).
Return(map[int]ipp.Attributes{
1: {
ipp.AttributeJobID: []ipp.Attribute{{Value: 1}},
ipp.AttributeJobName: []ipp.Attribute{{Value: "job1"}},
ipp.AttributeJobState: []ipp.Attribute{{Value: 5}},
},
}, nil)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.getJobs",
Params: map[string]interface{}{
"printerName": "printer1",
},
}
handleGetJobs(conn, req, m)
var resp models.Response[[]Job]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.NotNil(t, resp.Result)
assert.Equal(t, 1, len(*resp.Result))
}
func TestHandleGetJobs_MissingParam(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.getJobs",
Params: map[string]interface{}{},
}
handleGetJobs(conn, req, m)
var resp models.Response[interface{}]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.Nil(t, resp.Result)
assert.NotNil(t, resp.Error)
}
func TestHandlePausePrinter(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().PausePrinter("printer1").Return(nil)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.pausePrinter",
Params: map[string]interface{}{
"printerName": "printer1",
},
}
handlePausePrinter(conn, req, m)
var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
}
func TestHandleResumePrinter(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().ResumePrinter("printer1").Return(nil)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.resumePrinter",
Params: map[string]interface{}{
"printerName": "printer1",
},
}
handleResumePrinter(conn, req, m)
var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
}
func TestHandleCancelJob(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().CancelJob(1, false).Return(nil)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.cancelJob",
Params: map[string]interface{}{
"jobID": float64(1),
},
}
handleCancelJob(conn, req, m)
var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
}
func TestHandlePurgeJobs(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
mockClient.EXPECT().CancelAllJob("printer1", true).Return(nil)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.purgeJobs",
Params: map[string]interface{}{
"printerName": "printer1",
},
}
handlePurgeJobs(conn, req, m)
var resp models.Response[SuccessResult]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
}
func TestHandleRequest_UnknownMethod(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
m := &Manager{
client: mockClient,
}
buf := &bytes.Buffer{}
conn := &mockConn{Buffer: buf}
req := Request{
ID: 1,
Method: "cups.unknownMethod",
}
HandleRequest(conn, req, m)
var resp models.Response[interface{}]
err := json.NewDecoder(buf).Decode(&resp)
assert.NoError(t, err)
assert.Nil(t, resp.Result)
assert.NotNil(t, resp.Error)
}

View File

@@ -0,0 +1,340 @@
package cups
import (
"fmt"
"os"
"strconv"
"sync"
"time"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/AvengeMedia/danklinux/pkg/ipp"
)
func NewManager() (*Manager, error) {
host := os.Getenv("DMS_IPP_HOST")
if host == "" {
host = "localhost"
}
portStr := os.Getenv("DMS_IPP_PORT")
port := 631
if portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
username := os.Getenv("DMS_IPP_USERNAME")
password := os.Getenv("DMS_IPP_PASSWORD")
client := ipp.NewCUPSClient(host, port, username, password, false)
baseURL := fmt.Sprintf("http://%s:%d", host, port)
m := &Manager{
state: &CUPSState{
Printers: make(map[string]*Printer),
},
client: client,
baseURL: baseURL,
stateMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
subscribers: make(map[string]chan CUPSState),
subMutex: sync.RWMutex{},
}
if err := m.updateState(); err != nil {
return nil, err
}
if isLocalCUPS(host) {
m.subscription = NewDBusSubscriptionManager(client, baseURL)
log.Infof("[CUPS] Using D-Bus notifications for local CUPS")
} else {
m.subscription = NewSubscriptionManager(client, baseURL)
log.Infof("[CUPS] Using IPPGET notifications for remote CUPS")
}
m.notifierWg.Add(1)
go m.notifier()
return m, nil
}
func isLocalCUPS(host string) bool {
switch host {
case "localhost", "127.0.0.1", "::1", "":
return true
}
return false
}
func (m *Manager) eventHandler() {
defer m.eventWG.Done()
if m.subscription == nil {
return
}
for {
select {
case <-m.stopChan:
return
case event, ok := <-m.subscription.Events():
if !ok {
return
}
log.Debugf("[CUPS] Received event: %s (printer: %s, job: %d)",
event.EventName, event.PrinterName, event.JobID)
if err := m.updateState(); err != nil {
log.Warnf("[CUPS] Failed to update state after event: %v", err)
} else {
m.notifySubscribers()
}
}
}
}
func (m *Manager) updateState() error {
printers, err := m.GetPrinters()
if err != nil {
return err
}
printerMap := make(map[string]*Printer, len(printers))
for _, printer := range printers {
jobs, err := m.GetJobs(printer.Name, "not-completed")
if err != nil {
return err
}
printer.Jobs = jobs
printerMap[printer.Name] = &printer
}
m.stateMutex.Lock()
m.state.Printers = printerMap
m.stateMutex.Unlock()
return nil
}
func (m *Manager) notifier() {
defer m.notifierWg.Done()
const minGap = 100 * time.Millisecond
timer := time.NewTimer(minGap)
timer.Stop()
var pending bool
for {
select {
case <-m.stopChan:
timer.Stop()
return
case <-m.dirty:
if pending {
continue
}
pending = true
timer.Reset(minGap)
case <-timer.C:
if !pending {
continue
}
m.subMutex.RLock()
if len(m.subscribers) == 0 {
m.subMutex.RUnlock()
pending = false
continue
}
currentState := m.snapshotState()
if m.lastNotifiedState != nil && !stateChanged(m.lastNotifiedState, &currentState) {
m.subMutex.RUnlock()
pending = false
continue
}
for _, ch := range m.subscribers {
select {
case ch <- currentState:
default:
}
}
m.subMutex.RUnlock()
stateCopy := currentState
m.lastNotifiedState = &stateCopy
pending = false
}
}
}
func (m *Manager) notifySubscribers() {
select {
case m.dirty <- struct{}{}:
default:
}
}
func (m *Manager) GetState() CUPSState {
return m.snapshotState()
}
func (m *Manager) snapshotState() CUPSState {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
s := CUPSState{
Printers: make(map[string]*Printer, len(m.state.Printers)),
}
for name, printer := range m.state.Printers {
printerCopy := *printer
s.Printers[name] = &printerCopy
}
return s
}
func (m *Manager) Subscribe(id string) chan CUPSState {
ch := make(chan CUPSState, 64)
m.subMutex.Lock()
wasEmpty := len(m.subscribers) == 0
m.subscribers[id] = ch
m.subMutex.Unlock()
if wasEmpty && m.subscription != nil {
if err := m.subscription.Start(); err != nil {
log.Warnf("[CUPS] Failed to start subscription manager: %v", err)
} else {
m.eventWG.Add(1)
go m.eventHandler()
}
}
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
isEmpty := len(m.subscribers) == 0
m.subMutex.Unlock()
if isEmpty && m.subscription != nil {
m.subscription.Stop()
m.eventWG.Wait()
}
}
func (m *Manager) Close() {
close(m.stopChan)
if m.subscription != nil {
m.subscription.Stop()
}
m.eventWG.Wait()
m.notifierWg.Wait()
m.subMutex.Lock()
for _, ch := range m.subscribers {
close(ch)
}
m.subscribers = make(map[string]chan CUPSState)
m.subMutex.Unlock()
}
func stateChanged(old, new *CUPSState) bool {
if len(old.Printers) != len(new.Printers) {
return true
}
for name, oldPrinter := range old.Printers {
newPrinter, exists := new.Printers[name]
if !exists {
return true
}
if oldPrinter.State != newPrinter.State ||
oldPrinter.StateReason != newPrinter.StateReason ||
len(oldPrinter.Jobs) != len(newPrinter.Jobs) {
return true
}
}
return false
}
func parsePrinterState(attrs ipp.Attributes) string {
if stateAttr, ok := attrs[ipp.AttributePrinterState]; ok && len(stateAttr) > 0 {
if state, ok := stateAttr[0].Value.(int); ok {
switch state {
case 3:
return "idle"
case 4:
return "processing"
case 5:
return "stopped"
default:
return fmt.Sprintf("%d", state)
}
}
}
return "unknown"
}
func parseJobState(attrs ipp.Attributes) string {
if stateAttr, ok := attrs[ipp.AttributeJobState]; ok && len(stateAttr) > 0 {
if state, ok := stateAttr[0].Value.(int); ok {
switch state {
case 3:
return "pending"
case 4:
return "pending-held"
case 5:
return "processing"
case 6:
return "processing-stopped"
case 7:
return "canceled"
case 8:
return "aborted"
case 9:
return "completed"
default:
return fmt.Sprintf("%d", state)
}
}
}
return "unknown"
}
func getStringAttr(attrs ipp.Attributes, key string) string {
if attr, ok := attrs[key]; ok && len(attr) > 0 {
if val, ok := attr[0].Value.(string); ok {
return val
}
return fmt.Sprintf("%v", attr[0].Value)
}
return ""
}
func getIntAttr(attrs ipp.Attributes, key string) int {
if attr, ok := attrs[key]; ok && len(attr) > 0 {
if val, ok := attr[0].Value.(int); ok {
return val
}
}
return 0
}
func getBoolAttr(attrs ipp.Attributes, key string) bool {
if attr, ok := attrs[key]; ok && len(attr) > 0 {
if val, ok := attr[0].Value.(bool); ok {
return val
}
}
return false
}

View File

@@ -0,0 +1,351 @@
package cups
import (
"testing"
mocks_cups "github.com/AvengeMedia/danklinux/internal/mocks/cups"
"github.com/AvengeMedia/danklinux/pkg/ipp"
"github.com/stretchr/testify/assert"
)
func TestNewManager(t *testing.T) {
m := &Manager{
state: &CUPSState{
Printers: make(map[string]*Printer),
},
client: nil,
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
subscribers: make(map[string]chan CUPSState),
}
assert.NotNil(t, m)
assert.NotNil(t, m.state)
}
func TestManager_GetState(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
m := &Manager{
state: &CUPSState{
Printers: map[string]*Printer{
"test-printer": {
Name: "test-printer",
State: "idle",
},
},
},
client: mockClient,
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
subscribers: make(map[string]chan CUPSState),
}
state := m.GetState()
assert.Equal(t, 1, len(state.Printers))
assert.Equal(t, "test-printer", state.Printers["test-printer"].Name)
}
func TestManager_Subscribe(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
m := &Manager{
state: &CUPSState{
Printers: make(map[string]*Printer),
},
client: mockClient,
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
subscribers: make(map[string]chan CUPSState),
}
ch := m.Subscribe("test-client")
assert.NotNil(t, ch)
assert.Equal(t, 1, len(m.subscribers))
m.Unsubscribe("test-client")
assert.Equal(t, 0, len(m.subscribers))
}
func TestManager_Close(t *testing.T) {
mockClient := mocks_cups.NewMockCUPSClientInterface(t)
m := &Manager{
state: &CUPSState{
Printers: make(map[string]*Printer),
},
client: mockClient,
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
subscribers: make(map[string]chan CUPSState),
}
m.eventWG.Add(1)
go func() {
defer m.eventWG.Done()
<-m.stopChan
}()
m.notifierWg.Add(1)
go func() {
defer m.notifierWg.Done()
<-m.stopChan
}()
m.Close()
assert.Equal(t, 0, len(m.subscribers))
}
func TestStateChanged(t *testing.T) {
tests := []struct {
name string
oldState *CUPSState
newState *CUPSState
want bool
}{
{
name: "no change",
oldState: &CUPSState{
Printers: map[string]*Printer{
"p1": {Name: "p1", State: "idle"},
},
},
newState: &CUPSState{
Printers: map[string]*Printer{
"p1": {Name: "p1", State: "idle"},
},
},
want: false,
},
{
name: "state changed",
oldState: &CUPSState{
Printers: map[string]*Printer{
"p1": {Name: "p1", State: "idle"},
},
},
newState: &CUPSState{
Printers: map[string]*Printer{
"p1": {Name: "p1", State: "processing"},
},
},
want: true,
},
{
name: "printer added",
oldState: &CUPSState{
Printers: map[string]*Printer{},
},
newState: &CUPSState{
Printers: map[string]*Printer{
"p1": {Name: "p1", State: "idle"},
},
},
want: true,
},
{
name: "printer removed",
oldState: &CUPSState{
Printers: map[string]*Printer{
"p1": {Name: "p1", State: "idle"},
},
},
newState: &CUPSState{
Printers: map[string]*Printer{},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := stateChanged(tt.oldState, tt.newState)
assert.Equal(t, tt.want, got)
})
}
}
func TestParsePrinterState(t *testing.T) {
tests := []struct {
name string
attrs ipp.Attributes
want string
}{
{
name: "idle",
attrs: ipp.Attributes{
ipp.AttributePrinterState: []ipp.Attribute{{Value: 3}},
},
want: "idle",
},
{
name: "processing",
attrs: ipp.Attributes{
ipp.AttributePrinterState: []ipp.Attribute{{Value: 4}},
},
want: "processing",
},
{
name: "stopped",
attrs: ipp.Attributes{
ipp.AttributePrinterState: []ipp.Attribute{{Value: 5}},
},
want: "stopped",
},
{
name: "unknown",
attrs: ipp.Attributes{},
want: "unknown",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := parsePrinterState(tt.attrs)
assert.Equal(t, tt.want, got)
})
}
}
func TestParseJobState(t *testing.T) {
tests := []struct {
name string
attrs ipp.Attributes
want string
}{
{
name: "pending",
attrs: ipp.Attributes{
ipp.AttributeJobState: []ipp.Attribute{{Value: 3}},
},
want: "pending",
},
{
name: "processing",
attrs: ipp.Attributes{
ipp.AttributeJobState: []ipp.Attribute{{Value: 5}},
},
want: "processing",
},
{
name: "completed",
attrs: ipp.Attributes{
ipp.AttributeJobState: []ipp.Attribute{{Value: 9}},
},
want: "completed",
},
{
name: "unknown",
attrs: ipp.Attributes{},
want: "unknown",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := parseJobState(tt.attrs)
assert.Equal(t, tt.want, got)
})
}
}
func TestGetStringAttr(t *testing.T) {
tests := []struct {
name string
attrs ipp.Attributes
key string
want string
}{
{
name: "string value",
attrs: ipp.Attributes{
"test-key": []ipp.Attribute{{Value: "test-value"}},
},
key: "test-key",
want: "test-value",
},
{
name: "missing key",
attrs: ipp.Attributes{},
key: "missing",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getStringAttr(tt.attrs, tt.key)
assert.Equal(t, tt.want, got)
})
}
}
func TestGetIntAttr(t *testing.T) {
tests := []struct {
name string
attrs ipp.Attributes
key string
want int
}{
{
name: "int value",
attrs: ipp.Attributes{
"test-key": []ipp.Attribute{{Value: 42}},
},
key: "test-key",
want: 42,
},
{
name: "missing key",
attrs: ipp.Attributes{},
key: "missing",
want: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getIntAttr(tt.attrs, tt.key)
assert.Equal(t, tt.want, got)
})
}
}
func TestGetBoolAttr(t *testing.T) {
tests := []struct {
name string
attrs ipp.Attributes
key string
want bool
}{
{
name: "true value",
attrs: ipp.Attributes{
"test-key": []ipp.Attribute{{Value: true}},
},
key: "test-key",
want: true,
},
{
name: "false value",
attrs: ipp.Attributes{
"test-key": []ipp.Attribute{{Value: false}},
},
key: "test-key",
want: false,
},
{
name: "missing key",
attrs: ipp.Attributes{},
key: "missing",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getBoolAttr(tt.attrs, tt.key)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,245 @@
package cups
import (
"fmt"
"sync"
"time"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/AvengeMedia/danklinux/pkg/ipp"
)
type SubscriptionManager struct {
client CUPSClientInterface
subscriptionID int
sequenceNumber int
eventChan chan SubscriptionEvent
stopChan chan struct{}
wg sync.WaitGroup
baseURL string
running bool
mu sync.Mutex
}
func NewSubscriptionManager(client CUPSClientInterface, baseURL string) *SubscriptionManager {
return &SubscriptionManager{
client: client,
eventChan: make(chan SubscriptionEvent, 100),
stopChan: make(chan struct{}),
baseURL: baseURL,
}
}
func (sm *SubscriptionManager) Start() error {
sm.mu.Lock()
if sm.running {
sm.mu.Unlock()
return fmt.Errorf("subscription manager already running")
}
sm.running = true
sm.mu.Unlock()
subID, err := sm.createSubscription()
if err != nil {
sm.mu.Lock()
sm.running = false
sm.mu.Unlock()
return fmt.Errorf("failed to create subscription: %w", err)
}
sm.subscriptionID = subID
log.Infof("[CUPS] Created IPP subscription with ID %d", subID)
sm.wg.Add(1)
go sm.notificationLoop()
return nil
}
func (sm *SubscriptionManager) createSubscription() (int, error) {
req := ipp.NewRequest(ipp.OperationCreatePrinterSubscriptions, 1)
req.OperationAttributes[ipp.AttributePrinterURI] = fmt.Sprintf("%s/", sm.baseURL)
req.OperationAttributes[ipp.AttributeRequestingUserName] = "dms"
// Subscription attributes go in SubscriptionAttributes (subscription-attributes-tag in IPP)
req.SubscriptionAttributes = map[string]interface{}{
"notify-events": []string{
"printer-state-changed",
"printer-added",
"printer-deleted",
"job-created",
"job-completed",
"job-state-changed",
},
"notify-pull-method": "ippget",
"notify-lease-duration": 0,
}
// Send to root IPP endpoint
resp, err := sm.client.SendRequest(fmt.Sprintf("%s/", sm.baseURL), req, nil)
if err != nil {
return 0, fmt.Errorf("SendRequest failed: %w", err)
}
// Check for IPP errors
if err := resp.CheckForErrors(); err != nil {
return 0, fmt.Errorf("IPP error: %w", err)
}
// Subscription ID comes back in SubscriptionAttributes
if len(resp.SubscriptionAttributes) > 0 {
if idAttr, ok := resp.SubscriptionAttributes[0]["notify-subscription-id"]; ok && len(idAttr) > 0 {
if val, ok := idAttr[0].Value.(int); ok {
return val, nil
}
}
}
return 0, fmt.Errorf("no subscription ID returned")
}
func (sm *SubscriptionManager) notificationLoop() {
defer sm.wg.Done()
backoff := 1 * time.Second
for {
select {
case <-sm.stopChan:
return
default:
}
gotAny, err := sm.fetchNotificationsWithWait()
if err != nil {
log.Warnf("[CUPS] Error fetching notifications: %v", err)
jitter := time.Duration(50+(time.Now().UnixNano()%200)) * time.Millisecond
sleepTime := backoff + jitter
if sleepTime > 30*time.Second {
sleepTime = 30 * time.Second
}
select {
case <-sm.stopChan:
return
case <-time.After(sleepTime):
}
if backoff < 30*time.Second {
backoff *= 2
}
continue
}
backoff = 1 * time.Second
if gotAny {
continue
}
select {
case <-sm.stopChan:
return
case <-time.After(2 * time.Second):
}
}
}
func (sm *SubscriptionManager) fetchNotificationsWithWait() (bool, error) {
req := ipp.NewRequest(ipp.OperationGetNotifications, 1)
req.OperationAttributes[ipp.AttributePrinterURI] = fmt.Sprintf("%s/", sm.baseURL)
req.OperationAttributes[ipp.AttributeRequestingUserName] = "dms"
req.OperationAttributes["notify-subscription-ids"] = sm.subscriptionID
if sm.sequenceNumber > 0 {
req.OperationAttributes["notify-sequence-numbers"] = sm.sequenceNumber
}
resp, err := sm.client.SendRequest(fmt.Sprintf("%s/", sm.baseURL), req, nil)
if err != nil {
return false, err
}
gotAny := false
for _, eventGroup := range resp.SubscriptionAttributes {
if seqAttr, ok := eventGroup["notify-sequence-number"]; ok && len(seqAttr) > 0 {
if seqNum, ok := seqAttr[0].Value.(int); ok {
sm.sequenceNumber = seqNum + 1
}
}
event := sm.parseEvent(eventGroup)
gotAny = true
select {
case sm.eventChan <- event:
case <-sm.stopChan:
return gotAny, nil
default:
log.Warn("[CUPS] Event channel full, dropping event")
}
}
return gotAny, nil
}
func (sm *SubscriptionManager) parseEvent(attrs ipp.Attributes) SubscriptionEvent {
event := SubscriptionEvent{
SubscribedAt: time.Now(),
}
if attr, ok := attrs["notify-subscribed-event"]; ok && len(attr) > 0 {
if val, ok := attr[0].Value.(string); ok {
event.EventName = val
}
}
if attr, ok := attrs["printer-name"]; ok && len(attr) > 0 {
if val, ok := attr[0].Value.(string); ok {
event.PrinterName = val
}
}
if attr, ok := attrs["notify-job-id"]; ok && len(attr) > 0 {
if val, ok := attr[0].Value.(int); ok {
event.JobID = val
}
}
return event
}
func (sm *SubscriptionManager) Events() <-chan SubscriptionEvent {
return sm.eventChan
}
func (sm *SubscriptionManager) Stop() {
sm.mu.Lock()
if !sm.running {
sm.mu.Unlock()
return
}
sm.running = false
sm.mu.Unlock()
close(sm.stopChan)
sm.wg.Wait()
if sm.subscriptionID != 0 {
sm.cancelSubscription()
sm.subscriptionID = 0
sm.sequenceNumber = 0
}
sm.stopChan = make(chan struct{})
}
func (sm *SubscriptionManager) cancelSubscription() {
req := ipp.NewRequest(ipp.OperationCancelSubscription, 1)
req.OperationAttributes[ipp.AttributePrinterURI] = fmt.Sprintf("%s/", sm.baseURL)
req.OperationAttributes[ipp.AttributeRequestingUserName] = "dms"
req.OperationAttributes["notify-subscription-id"] = sm.subscriptionID
_, err := sm.client.SendRequest(fmt.Sprintf("%s/", sm.baseURL), req, nil)
if err != nil {
log.Warnf("[CUPS] Failed to cancel subscription %d: %v", sm.subscriptionID, err)
} else {
log.Infof("[CUPS] Cancelled subscription %d", sm.subscriptionID)
}
}

View File

@@ -0,0 +1,295 @@
package cups
import (
"fmt"
"strings"
"sync"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/AvengeMedia/danklinux/pkg/ipp"
"github.com/godbus/dbus/v5"
)
type DBusSubscriptionManager struct {
client CUPSClientInterface
subscriptionID int
eventChan chan SubscriptionEvent
stopChan chan struct{}
wg sync.WaitGroup
baseURL string
running bool
mu sync.Mutex
conn *dbus.Conn
}
func NewDBusSubscriptionManager(client CUPSClientInterface, baseURL string) *DBusSubscriptionManager {
return &DBusSubscriptionManager{
client: client,
eventChan: make(chan SubscriptionEvent, 100),
stopChan: make(chan struct{}),
baseURL: baseURL,
}
}
func (sm *DBusSubscriptionManager) Start() error {
sm.mu.Lock()
if sm.running {
sm.mu.Unlock()
return fmt.Errorf("subscription manager already running")
}
sm.running = true
sm.mu.Unlock()
conn, err := dbus.ConnectSystemBus()
if err != nil {
sm.mu.Lock()
sm.running = false
sm.mu.Unlock()
return fmt.Errorf("connect to system bus: %w", err)
}
sm.conn = conn
subID, err := sm.createDBusSubscription()
if err != nil {
sm.conn.Close()
sm.mu.Lock()
sm.running = false
sm.mu.Unlock()
return fmt.Errorf("failed to create D-Bus subscription: %w", err)
}
sm.subscriptionID = subID
log.Infof("[CUPS] Created D-Bus subscription with ID %d", subID)
if err := sm.conn.AddMatchSignal(
dbus.WithMatchInterface("org.cups.cupsd.Notifier"),
); err != nil {
sm.cancelSubscription()
sm.conn.Close()
sm.mu.Lock()
sm.running = false
sm.mu.Unlock()
return fmt.Errorf("failed to add D-Bus match: %w", err)
}
sm.wg.Add(1)
go sm.dbusListenerLoop()
return nil
}
func (sm *DBusSubscriptionManager) createDBusSubscription() (int, error) {
req := ipp.NewRequest(ipp.OperationCreatePrinterSubscriptions, 2)
req.OperationAttributes[ipp.AttributePrinterURI] = fmt.Sprintf("%s/", sm.baseURL)
req.OperationAttributes[ipp.AttributeRequestingUserName] = "dms"
req.SubscriptionAttributes = map[string]interface{}{
"notify-events": []string{
"printer-state-changed",
"printer-added",
"printer-deleted",
"job-created",
"job-completed",
"job-state-changed",
},
"notify-recipient-uri": "dbus:/",
"notify-lease-duration": 86400,
}
resp, err := sm.client.SendRequest(fmt.Sprintf("%s/", sm.baseURL), req, nil)
if err != nil {
return 0, fmt.Errorf("SendRequest failed: %w", err)
}
if err := resp.CheckForErrors(); err != nil {
return 0, fmt.Errorf("IPP error: %w", err)
}
if len(resp.SubscriptionAttributes) > 0 {
if idAttr, ok := resp.SubscriptionAttributes[0]["notify-subscription-id"]; ok && len(idAttr) > 0 {
if val, ok := idAttr[0].Value.(int); ok {
return val, nil
}
}
}
return 0, fmt.Errorf("no subscription ID returned")
}
func (sm *DBusSubscriptionManager) dbusListenerLoop() {
defer sm.wg.Done()
signalChan := make(chan *dbus.Signal, 10)
sm.conn.Signal(signalChan)
defer sm.conn.RemoveSignal(signalChan)
for {
select {
case <-sm.stopChan:
return
case sig := <-signalChan:
if sig == nil {
continue
}
event := sm.parseDBusSignal(sig)
if event.EventName == "" {
continue
}
select {
case sm.eventChan <- event:
case <-sm.stopChan:
return
default:
log.Warn("[CUPS] Event channel full, dropping event")
}
}
}
}
func (sm *DBusSubscriptionManager) parseDBusSignal(sig *dbus.Signal) SubscriptionEvent {
event := SubscriptionEvent{}
switch sig.Name {
case "org.cups.cupsd.Notifier.JobStateChanged":
if len(sig.Body) >= 6 {
if text, ok := sig.Body[0].(string); ok {
event.EventName = "job-state-changed"
parts := strings.Split(text, " ")
if len(parts) >= 2 {
event.PrinterName = parts[0]
}
}
if printerURI, ok := sig.Body[1].(string); ok && event.PrinterName == "" {
if idx := strings.LastIndex(printerURI, "/"); idx != -1 {
event.PrinterName = printerURI[idx+1:]
}
}
if jobID, ok := sig.Body[3].(uint32); ok {
event.JobID = int(jobID)
}
}
case "org.cups.cupsd.Notifier.JobCreated":
if len(sig.Body) >= 6 {
if text, ok := sig.Body[0].(string); ok {
event.EventName = "job-created"
parts := strings.Split(text, " ")
if len(parts) >= 2 {
event.PrinterName = parts[0]
}
}
if printerURI, ok := sig.Body[1].(string); ok && event.PrinterName == "" {
if idx := strings.LastIndex(printerURI, "/"); idx != -1 {
event.PrinterName = printerURI[idx+1:]
}
}
if jobID, ok := sig.Body[3].(uint32); ok {
event.JobID = int(jobID)
}
}
case "org.cups.cupsd.Notifier.JobCompleted":
if len(sig.Body) >= 6 {
if text, ok := sig.Body[0].(string); ok {
event.EventName = "job-completed"
parts := strings.Split(text, " ")
if len(parts) >= 2 {
event.PrinterName = parts[0]
}
}
if printerURI, ok := sig.Body[1].(string); ok && event.PrinterName == "" {
if idx := strings.LastIndex(printerURI, "/"); idx != -1 {
event.PrinterName = printerURI[idx+1:]
}
}
if jobID, ok := sig.Body[3].(uint32); ok {
event.JobID = int(jobID)
}
}
case "org.cups.cupsd.Notifier.PrinterStateChanged":
if len(sig.Body) >= 6 {
if text, ok := sig.Body[0].(string); ok {
event.EventName = "printer-state-changed"
parts := strings.Split(text, " ")
if len(parts) >= 2 {
event.PrinterName = parts[0]
}
}
if printerURI, ok := sig.Body[1].(string); ok && event.PrinterName == "" {
if idx := strings.LastIndex(printerURI, "/"); idx != -1 {
event.PrinterName = printerURI[idx+1:]
}
}
}
case "org.cups.cupsd.Notifier.PrinterAdded":
if len(sig.Body) >= 6 {
if text, ok := sig.Body[0].(string); ok {
event.EventName = "printer-added"
parts := strings.Split(text, " ")
if len(parts) >= 2 {
event.PrinterName = parts[0]
}
}
}
case "org.cups.cupsd.Notifier.PrinterDeleted":
if len(sig.Body) >= 6 {
if text, ok := sig.Body[0].(string); ok {
event.EventName = "printer-deleted"
parts := strings.Split(text, " ")
if len(parts) >= 2 {
event.PrinterName = parts[0]
}
}
}
}
return event
}
func (sm *DBusSubscriptionManager) Events() <-chan SubscriptionEvent {
return sm.eventChan
}
func (sm *DBusSubscriptionManager) Stop() {
sm.mu.Lock()
if !sm.running {
sm.mu.Unlock()
return
}
sm.running = false
sm.mu.Unlock()
close(sm.stopChan)
sm.wg.Wait()
if sm.subscriptionID != 0 {
sm.cancelSubscription()
sm.subscriptionID = 0
}
if sm.conn != nil {
sm.conn.Close()
sm.conn = nil
}
sm.stopChan = make(chan struct{})
}
func (sm *DBusSubscriptionManager) cancelSubscription() {
req := ipp.NewRequest(ipp.OperationCancelSubscription, 1)
req.OperationAttributes[ipp.AttributePrinterURI] = fmt.Sprintf("%s/", sm.baseURL)
req.OperationAttributes[ipp.AttributeRequestingUserName] = "dms"
req.OperationAttributes["notify-subscription-id"] = sm.subscriptionID
_, err := sm.client.SendRequest(fmt.Sprintf("%s/", sm.baseURL), req, nil)
if err != nil {
log.Warnf("[CUPS] Failed to cancel subscription %d: %v", sm.subscriptionID, err)
} else {
log.Infof("[CUPS] Cancelled subscription %d", sm.subscriptionID)
}
}

View File

@@ -0,0 +1,73 @@
package cups
import (
"io"
"sync"
"time"
"github.com/AvengeMedia/danklinux/pkg/ipp"
)
type CUPSState struct {
Printers map[string]*Printer `json:"printers"`
}
type Printer struct {
Name string `json:"name"`
URI string `json:"uri"`
State string `json:"state"`
StateReason string `json:"stateReason"`
Location string `json:"location"`
Info string `json:"info"`
MakeModel string `json:"makeModel"`
Accepting bool `json:"accepting"`
Jobs []Job `json:"jobs"`
}
type Job struct {
ID int `json:"id"`
Name string `json:"name"`
State string `json:"state"`
Printer string `json:"printer"`
User string `json:"user"`
Size int `json:"size"`
TimeCreated time.Time `json:"timeCreated"`
}
type Manager struct {
state *CUPSState
client CUPSClientInterface
subscription SubscriptionManagerInterface
stateMutex sync.RWMutex
subscribers map[string]chan CUPSState
subMutex sync.RWMutex
stopChan chan struct{}
eventWG sync.WaitGroup
dirty chan struct{}
notifierWg sync.WaitGroup
lastNotifiedState *CUPSState
baseURL string
}
type SubscriptionManagerInterface interface {
Start() error
Stop()
Events() <-chan SubscriptionEvent
}
type CUPSClientInterface interface {
GetPrinters(attributes []string) (map[string]ipp.Attributes, error)
GetJobs(printer, class string, whichJobs string, myJobs bool, firstJobId, limit int, attributes []string) (map[int]ipp.Attributes, error)
CancelJob(jobID int, purge bool) error
PausePrinter(printer string) error
ResumePrinter(printer string) error
CancelAllJob(printer string, purge bool) error
SendRequest(url string, req *ipp.Request, additionalResponseData io.Writer) (*ipp.Response, error)
}
type SubscriptionEvent struct {
EventName string
PrinterName string
JobID int
SubscribedAt time.Time
}

View File

@@ -0,0 +1,144 @@
package dwl
import (
"encoding/json"
"fmt"
"net"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
if manager == nil {
models.RespondError(conn, req.ID, "dwl manager not initialized")
return
}
switch req.Method {
case "dwl.getState":
handleGetState(conn, req, manager)
case "dwl.setTags":
handleSetTags(conn, req, manager)
case "dwl.setClientTags":
handleSetClientTags(conn, req, manager)
case "dwl.setLayout":
handleSetLayout(conn, req, manager)
case "dwl.subscribe":
handleSubscribe(conn, req, manager)
default:
models.RespondError(conn, req.ID, fmt.Sprintf("unknown method: %s", req.Method))
}
}
func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState()
models.Respond(conn, req.ID, state)
}
func handleSetTags(conn net.Conn, req Request, manager *Manager) {
output, ok := req.Params["output"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'output' parameter")
return
}
tagmask, ok := req.Params["tagmask"].(float64)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'tagmask' parameter")
return
}
toggleTagset, ok := req.Params["toggleTagset"].(float64)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'toggleTagset' parameter")
return
}
if err := manager.SetTags(output, uint32(tagmask), uint32(toggleTagset)); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "tags set"})
}
func handleSetClientTags(conn net.Conn, req Request, manager *Manager) {
output, ok := req.Params["output"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'output' parameter")
return
}
andTags, ok := req.Params["andTags"].(float64)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'andTags' parameter")
return
}
xorTags, ok := req.Params["xorTags"].(float64)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'xorTags' parameter")
return
}
if err := manager.SetClientTags(output, uint32(andTags), uint32(xorTags)); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "client tags set"})
}
func handleSetLayout(conn net.Conn, req Request, manager *Manager) {
output, ok := req.Params["output"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'output' parameter")
return
}
index, ok := req.Params["index"].(float64)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'index' parameter")
return
}
if err := manager.SetLayout(output, uint32(index)); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "layout set"})
}
func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID)
initialState := manager.GetState()
if err := json.NewEncoder(conn).Encode(models.Response[State]{
ID: req.ID,
Result: &initialState,
}); err != nil {
return
}
for state := range stateChan {
if err := json.NewEncoder(conn).Encode(models.Response[State]{
Result: &state,
}); err != nil {
return
}
}
}

View File

@@ -0,0 +1,539 @@
package dwl
import (
"fmt"
"time"
wlclient "github.com/yaslama/go-wayland/wayland/client"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/AvengeMedia/danklinux/internal/proto/dwl_ipc"
)
func NewManager(display *wlclient.Display) (*Manager, error) {
m := &Manager{
display: display,
outputs: make(map[uint32]*outputState),
cmdq: make(chan cmd, 128),
outputSetupReq: make(chan uint32, 16),
stopChan: make(chan struct{}),
subscribers: make(map[string]chan State),
dirty: make(chan struct{}, 1),
layouts: make([]string, 0),
}
if err := m.setupRegistry(); err != nil {
return nil, err
}
m.updateState()
m.notifierWg.Add(1)
go m.notifier()
m.wg.Add(1)
go m.waylandActor()
return m, nil
}
func (m *Manager) post(fn func()) {
select {
case m.cmdq <- cmd{fn: fn}:
default:
log.Warn("DWL actor command queue full, dropping command")
}
}
func (m *Manager) waylandActor() {
defer m.wg.Done()
for {
select {
case <-m.stopChan:
return
case c := <-m.cmdq:
c.fn()
case outputID := <-m.outputSetupReq:
m.outputsMutex.RLock()
out, exists := m.outputs[outputID]
m.outputsMutex.RUnlock()
if !exists {
log.Warnf("DWL: Output %d no longer exists, skipping setup", outputID)
continue
}
if out.ipcOutput != nil {
continue
}
mgr, ok := m.manager.(*dwl_ipc.ZdwlIpcManagerV2)
if !ok || mgr == nil {
log.Errorf("DWL: Manager not available for output %d setup", outputID)
continue
}
log.Infof("DWL: Setting up ipcOutput for dynamically added output %d", outputID)
if err := m.setupOutput(mgr, out.output); err != nil {
log.Errorf("DWL: Failed to setup output %d: %v", outputID, err)
} else {
m.updateState()
}
}
}
}
func (m *Manager) setupRegistry() error {
log.Info("DWL: starting registry setup")
ctx := m.display.Context()
registry, err := m.display.GetRegistry()
if err != nil {
return fmt.Errorf("failed to get registry: %w", err)
}
m.registry = registry
outputs := make([]*wlclient.Output, 0)
outputRegNames := make(map[uint32]uint32)
var dwlMgr *dwl_ipc.ZdwlIpcManagerV2
registry.SetGlobalHandler(func(e wlclient.RegistryGlobalEvent) {
switch e.Interface {
case dwl_ipc.ZdwlIpcManagerV2InterfaceName:
log.Infof("DWL: found %s", dwl_ipc.ZdwlIpcManagerV2InterfaceName)
manager := dwl_ipc.NewZdwlIpcManagerV2(ctx)
version := e.Version
if version > 1 {
version = 1
}
if err := registry.Bind(e.Name, e.Interface, version, manager); err == nil {
dwlMgr = manager
log.Info("DWL: manager bound successfully")
} else {
log.Errorf("DWL: failed to bind manager: %v", err)
}
case "wl_output":
log.Debugf("DWL: found wl_output (name=%d)", e.Name)
output := wlclient.NewOutput(ctx)
outState := &outputState{
registryName: e.Name,
output: output,
tags: make([]TagState, 0),
}
output.SetNameHandler(func(ev wlclient.OutputNameEvent) {
log.Debugf("DWL: Output name: %s (registry=%d)", ev.Name, e.Name)
outState.name = ev.Name
})
output.SetDescriptionHandler(func(ev wlclient.OutputDescriptionEvent) {
log.Debugf("DWL: Output description: %s", ev.Description)
})
version := e.Version
if version > 4 {
version = 4
}
if err := registry.Bind(e.Name, e.Interface, version, output); err == nil {
outputID := output.ID()
outState.id = outputID
log.Infof("DWL: Bound wl_output id=%d registry_name=%d", outputID, e.Name)
outputs = append(outputs, output)
outputRegNames[outputID] = e.Name
m.outputsMutex.Lock()
m.outputs[outputID] = outState
m.outputsMutex.Unlock()
if m.manager != nil {
select {
case m.outputSetupReq <- outputID:
log.Debugf("DWL: Queued setup for output %d", outputID)
default:
log.Warnf("DWL: Setup queue full, output %d will not be initialized", outputID)
}
}
} else {
log.Errorf("DWL: Failed to bind wl_output: %v", err)
}
}
})
registry.SetGlobalRemoveHandler(func(e wlclient.RegistryGlobalRemoveEvent) {
m.post(func() {
m.outputsMutex.Lock()
var outToRelease *outputState
for id, out := range m.outputs {
if out.registryName == e.Name {
log.Infof("DWL: Output %d removed", id)
outToRelease = out
delete(m.outputs, id)
break
}
}
m.outputsMutex.Unlock()
if outToRelease != nil {
if ipcOut, ok := outToRelease.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2); ok && ipcOut != nil {
m.wlMutex.Lock()
ipcOut.Release()
m.wlMutex.Unlock()
log.Debugf("DWL: Released ipcOutput for removed output %d", outToRelease.id)
}
m.updateState()
}
})
})
if err := m.display.Roundtrip(); err != nil {
return fmt.Errorf("first roundtrip failed: %w", err)
}
if err := m.display.Roundtrip(); err != nil {
return fmt.Errorf("second roundtrip failed: %w", err)
}
if dwlMgr == nil {
log.Info("DWL: manager not found in registry")
return fmt.Errorf("dwl_ipc_manager_v2 not available")
}
dwlMgr.SetTagsHandler(func(e dwl_ipc.ZdwlIpcManagerV2TagsEvent) {
log.Infof("DWL: Tags count: %d", e.Amount)
m.tagCount = e.Amount
m.updateState()
})
dwlMgr.SetLayoutHandler(func(e dwl_ipc.ZdwlIpcManagerV2LayoutEvent) {
log.Infof("DWL: Layout: %s", e.Name)
m.layouts = append(m.layouts, e.Name)
m.updateState()
})
m.manager = dwlMgr
for _, output := range outputs {
if err := m.setupOutput(dwlMgr, output); err != nil {
log.Warnf("DWL: Failed to setup output %d: %v", output.ID(), err)
}
}
if err := m.display.Roundtrip(); err != nil {
return fmt.Errorf("final roundtrip failed: %w", err)
}
log.Info("DWL: registry setup complete")
return nil
}
func (m *Manager) setupOutput(manager *dwl_ipc.ZdwlIpcManagerV2, output *wlclient.Output) error {
m.wlMutex.Lock()
ipcOutput, err := manager.GetOutput(output)
m.wlMutex.Unlock()
if err != nil {
return fmt.Errorf("failed to get dwl output: %w", err)
}
m.outputsMutex.Lock()
outState, exists := m.outputs[output.ID()]
if !exists {
m.outputsMutex.Unlock()
return fmt.Errorf("output state not found for id %d", output.ID())
}
outState.ipcOutput = ipcOutput
m.outputsMutex.Unlock()
ipcOutput.SetActiveHandler(func(e dwl_ipc.ZdwlIpcOutputV2ActiveEvent) {
outState.active = e.Active
})
ipcOutput.SetTagHandler(func(e dwl_ipc.ZdwlIpcOutputV2TagEvent) {
updated := false
for i, tag := range outState.tags {
if tag.Tag == e.Tag {
outState.tags[i] = TagState{
Tag: e.Tag,
State: e.State,
Clients: e.Clients,
Focused: e.Focused,
}
updated = true
break
}
}
if !updated {
outState.tags = append(outState.tags, TagState{
Tag: e.Tag,
State: e.State,
Clients: e.Clients,
Focused: e.Focused,
})
}
m.updateState()
})
ipcOutput.SetLayoutHandler(func(e dwl_ipc.ZdwlIpcOutputV2LayoutEvent) {
outState.layout = e.Layout
})
ipcOutput.SetTitleHandler(func(e dwl_ipc.ZdwlIpcOutputV2TitleEvent) {
outState.title = e.Title
})
ipcOutput.SetAppidHandler(func(e dwl_ipc.ZdwlIpcOutputV2AppidEvent) {
outState.appID = e.Appid
})
ipcOutput.SetLayoutSymbolHandler(func(e dwl_ipc.ZdwlIpcOutputV2LayoutSymbolEvent) {
outState.layoutSymbol = e.Layout
})
ipcOutput.SetFrameHandler(func(e dwl_ipc.ZdwlIpcOutputV2FrameEvent) {
m.updateState()
})
return nil
}
func (m *Manager) updateState() {
m.outputsMutex.RLock()
outputs := make(map[string]*OutputState)
activeOutput := ""
for _, out := range m.outputs {
name := out.name
if name == "" {
name = fmt.Sprintf("output-%d", out.id)
}
tagsCopy := make([]TagState, len(out.tags))
copy(tagsCopy, out.tags)
outputs[name] = &OutputState{
Name: name,
Active: out.active,
Tags: tagsCopy,
Layout: out.layout,
LayoutSymbol: out.layoutSymbol,
Title: out.title,
AppID: out.appID,
}
if out.active != 0 {
activeOutput = name
}
}
m.outputsMutex.RUnlock()
newState := State{
Outputs: outputs,
TagCount: m.tagCount,
Layouts: m.layouts,
ActiveOutput: activeOutput,
}
m.stateMutex.Lock()
m.state = &newState
m.stateMutex.Unlock()
m.notifySubscribers()
}
func (m *Manager) notifier() {
defer m.notifierWg.Done()
const minGap = 100 * time.Millisecond
timer := time.NewTimer(minGap)
timer.Stop()
var pending bool
for {
select {
case <-m.stopChan:
timer.Stop()
return
case <-m.dirty:
if pending {
continue
}
pending = true
timer.Reset(minGap)
case <-timer.C:
if !pending {
continue
}
m.subMutex.RLock()
subCount := len(m.subscribers)
m.subMutex.RUnlock()
if subCount == 0 {
pending = false
continue
}
currentState := m.GetState()
if m.lastNotified != nil && !stateChanged(m.lastNotified, &currentState) {
pending = false
continue
}
m.subMutex.RLock()
for _, ch := range m.subscribers {
select {
case ch <- currentState:
default:
log.Warn("DWL: subscriber channel full, dropping update")
}
}
m.subMutex.RUnlock()
stateCopy := currentState
m.lastNotified = &stateCopy
pending = false
}
}
}
func (m *Manager) ensureOutputSetup(out *outputState) error {
if out.ipcOutput != nil {
return nil
}
return fmt.Errorf("output not yet initialized - setup in progress, retry in a moment")
}
func (m *Manager) SetTags(outputName string, tagmask uint32, toggleTagset uint32) error {
m.outputsMutex.RLock()
availableOutputs := make([]string, 0, len(m.outputs))
var targetOut *outputState
for _, out := range m.outputs {
name := out.name
if name == "" {
name = fmt.Sprintf("output-%d", out.id)
}
availableOutputs = append(availableOutputs, name)
if name == outputName {
targetOut = out
break
}
}
m.outputsMutex.RUnlock()
if targetOut == nil {
return fmt.Errorf("output not found: %s (available: %v)", outputName, availableOutputs)
}
if err := m.ensureOutputSetup(targetOut); err != nil {
return fmt.Errorf("failed to setup output %s: %w", outputName, err)
}
ipcOut, ok := targetOut.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2)
if !ok {
return fmt.Errorf("output %s has invalid ipcOutput type", outputName)
}
m.wlMutex.Lock()
err := ipcOut.SetTags(tagmask, toggleTagset)
m.wlMutex.Unlock()
return err
}
func (m *Manager) SetClientTags(outputName string, andTags uint32, xorTags uint32) error {
m.outputsMutex.RLock()
var targetOut *outputState
for _, out := range m.outputs {
name := out.name
if name == "" {
name = fmt.Sprintf("output-%d", out.id)
}
if name == outputName {
targetOut = out
break
}
}
m.outputsMutex.RUnlock()
if targetOut == nil {
return fmt.Errorf("output not found: %s", outputName)
}
if err := m.ensureOutputSetup(targetOut); err != nil {
return fmt.Errorf("failed to setup output %s: %w", outputName, err)
}
ipcOut, ok := targetOut.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2)
if !ok {
return fmt.Errorf("output %s has invalid ipcOutput type", outputName)
}
m.wlMutex.Lock()
err := ipcOut.SetClientTags(andTags, xorTags)
m.wlMutex.Unlock()
return err
}
func (m *Manager) SetLayout(outputName string, index uint32) error {
m.outputsMutex.RLock()
var targetOut *outputState
for _, out := range m.outputs {
name := out.name
if name == "" {
name = fmt.Sprintf("output-%d", out.id)
}
if name == outputName {
targetOut = out
break
}
}
m.outputsMutex.RUnlock()
if targetOut == nil {
return fmt.Errorf("output not found: %s", outputName)
}
if err := m.ensureOutputSetup(targetOut); err != nil {
return fmt.Errorf("failed to setup output %s: %w", outputName, err)
}
ipcOut, ok := targetOut.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2)
if !ok {
return fmt.Errorf("output %s has invalid ipcOutput type", outputName)
}
m.wlMutex.Lock()
err := ipcOut.SetLayout(index)
m.wlMutex.Unlock()
return err
}
func (m *Manager) Close() {
close(m.stopChan)
m.wg.Wait()
m.notifierWg.Wait()
m.subMutex.Lock()
for _, ch := range m.subscribers {
close(ch)
}
m.subscribers = make(map[string]chan State)
m.subMutex.Unlock()
m.outputsMutex.Lock()
for _, out := range m.outputs {
if ipcOut, ok := out.ipcOutput.(*dwl_ipc.ZdwlIpcOutputV2); ok {
ipcOut.Release()
}
}
m.outputs = make(map[uint32]*outputState)
m.outputsMutex.Unlock()
if mgr, ok := m.manager.(*dwl_ipc.ZdwlIpcManagerV2); ok {
mgr.Release()
}
}

View File

@@ -0,0 +1,169 @@
package dwl
import (
"sync"
wlclient "github.com/yaslama/go-wayland/wayland/client"
)
type TagState struct {
Tag uint32 `json:"tag"`
State uint32 `json:"state"`
Clients uint32 `json:"clients"`
Focused uint32 `json:"focused"`
}
type OutputState struct {
Name string `json:"name"`
Active uint32 `json:"active"`
Tags []TagState `json:"tags"`
Layout uint32 `json:"layout"`
LayoutSymbol string `json:"layoutSymbol"`
Title string `json:"title"`
AppID string `json:"appId"`
}
type State struct {
Outputs map[string]*OutputState `json:"outputs"`
TagCount uint32 `json:"tagCount"`
Layouts []string `json:"layouts"`
ActiveOutput string `json:"activeOutput"`
}
type cmd struct {
fn func()
}
type Manager struct {
display *wlclient.Display
registry *wlclient.Registry
manager interface{}
outputs map[uint32]*outputState
outputsMutex sync.RWMutex
tagCount uint32
layouts []string
wlMutex sync.Mutex
cmdq chan cmd
outputSetupReq chan uint32
stopChan chan struct{}
wg sync.WaitGroup
subscribers map[string]chan State
subMutex sync.RWMutex
dirty chan struct{}
notifierWg sync.WaitGroup
lastNotified *State
stateMutex sync.RWMutex
state *State
}
type outputState struct {
id uint32
registryName uint32
output *wlclient.Output
ipcOutput interface{}
name string
active uint32
tags []TagState
layout uint32
layoutSymbol string
title string
appID string
}
func (m *Manager) GetState() State {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
if m.state == nil {
return State{
Outputs: make(map[string]*OutputState),
Layouts: []string{},
TagCount: 0,
}
}
stateCopy := *m.state
return stateCopy
}
func (m *Manager) Subscribe(id string) chan State {
ch := make(chan State, 64)
m.subMutex.Lock()
m.subscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) notifySubscribers() {
select {
case m.dirty <- struct{}{}:
default:
}
}
func stateChanged(old, new *State) bool {
if old == nil || new == nil {
return true
}
if old.TagCount != new.TagCount {
return true
}
if len(old.Layouts) != len(new.Layouts) {
return true
}
if old.ActiveOutput != new.ActiveOutput {
return true
}
if len(old.Outputs) != len(new.Outputs) {
return true
}
for name, newOut := range new.Outputs {
oldOut, exists := old.Outputs[name]
if !exists {
return true
}
if oldOut.Active != newOut.Active {
return true
}
if oldOut.Layout != newOut.Layout {
return true
}
if oldOut.LayoutSymbol != newOut.LayoutSymbol {
return true
}
if oldOut.Title != newOut.Title {
return true
}
if oldOut.AppID != newOut.AppID {
return true
}
if len(oldOut.Tags) != len(newOut.Tags) {
return true
}
for i, newTag := range newOut.Tags {
if i >= len(oldOut.Tags) {
return true
}
oldTag := oldOut.Tags[i]
if oldTag.Tag != newTag.Tag || oldTag.State != newTag.State ||
oldTag.Clients != newTag.Clients || oldTag.Focused != newTag.Focused {
return true
}
}
}
return false
}

View File

@@ -0,0 +1,128 @@
package freedesktop
import (
"context"
"fmt"
"os/exec"
"time"
"github.com/godbus/dbus/v5"
)
func (m *Manager) SetIconFile(iconPath string) error {
if !m.state.Accounts.Available || m.accountsObj == nil {
return fmt.Errorf("accounts service not available")
}
err := m.accountsObj.Call(dbusAccountsUserInterface+".SetIconFile", 0, iconPath).Err
if err != nil {
return fmt.Errorf("failed to set icon file: %w", err)
}
m.updateAccountsState()
return nil
}
func (m *Manager) SetRealName(name string) error {
if !m.state.Accounts.Available || m.accountsObj == nil {
return fmt.Errorf("accounts service not available")
}
err := m.accountsObj.Call(dbusAccountsUserInterface+".SetRealName", 0, name).Err
if err != nil {
return fmt.Errorf("failed to set real name: %w", err)
}
m.updateAccountsState()
return nil
}
func (m *Manager) SetEmail(email string) error {
if !m.state.Accounts.Available || m.accountsObj == nil {
return fmt.Errorf("accounts service not available")
}
err := m.accountsObj.Call(dbusAccountsUserInterface+".SetEmail", 0, email).Err
if err != nil {
return fmt.Errorf("failed to set email: %w", err)
}
m.updateAccountsState()
return nil
}
func (m *Manager) SetLanguage(language string) error {
if !m.state.Accounts.Available || m.accountsObj == nil {
return fmt.Errorf("accounts service not available")
}
err := m.accountsObj.Call(dbusAccountsUserInterface+".SetLanguage", 0, language).Err
if err != nil {
return fmt.Errorf("failed to set language: %w", err)
}
m.updateAccountsState()
return nil
}
func (m *Manager) SetLocation(location string) error {
if !m.state.Accounts.Available || m.accountsObj == nil {
return fmt.Errorf("accounts service not available")
}
err := m.accountsObj.Call(dbusAccountsUserInterface+".SetLocation", 0, location).Err
if err != nil {
return fmt.Errorf("failed to set location: %w", err)
}
m.updateAccountsState()
return nil
}
func (m *Manager) GetUserIconFile(username string) (string, error) {
if m.systemConn == nil {
return "", fmt.Errorf("accounts service not available")
}
accountsManager := m.systemConn.Object(dbusAccountsDest, dbus.ObjectPath(dbusAccountsPath))
var userPath dbus.ObjectPath
err := accountsManager.Call(dbusAccountsInterface+".FindUserByName", 0, username).Store(&userPath)
if err != nil {
return "", fmt.Errorf("user not found: %w", err)
}
userObj := m.systemConn.Object(dbusAccountsDest, userPath)
variant, err := userObj.GetProperty(dbusAccountsUserInterface + ".IconFile")
if err != nil {
return "", err
}
var iconFile string
if err := variant.Store(&iconFile); err != nil {
return "", err
}
return iconFile, nil
}
func (m *Manager) SetIconTheme(iconTheme string) error {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
check := exec.CommandContext(ctx, "gsettings", "writable", "org.gnome.desktop.interface", "icon-theme")
if err := check.Run(); err == nil {
cmd := exec.CommandContext(ctx, "gsettings", "set", "org.gnome.desktop.interface", "icon-theme", iconTheme)
if err := cmd.Run(); err != nil {
return fmt.Errorf("gsettings set failed: %w", err)
}
return nil
}
checkDconf := exec.CommandContext(ctx, "dconf", "write", "/org/gnome/desktop/interface/icon-theme", fmt.Sprintf("'%s'", iconTheme))
if err := checkDconf.Run(); err != nil {
return fmt.Errorf("both gsettings and dconf unavailable or failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,145 @@
package freedesktop
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestManager_SetIconFile(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.SetIconFile("/path/to/icon.png")
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
})
}
func TestManager_SetRealName(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.SetRealName("New Name")
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
})
}
func TestManager_SetEmail(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.SetEmail("test@example.com")
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
})
}
func TestManager_SetLanguage(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.SetLanguage("en_US.UTF-8")
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
})
}
func TestManager_SetLocation(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.SetLocation("Test Location")
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
})
}
func TestManager_GetUserIconFile(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
iconFile, err := manager.GetUserIconFile("testuser")
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
assert.Empty(t, iconFile)
})
}
func TestManager_UpdateAccountsState(t *testing.T) {
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.updateAccountsState()
assert.Error(t, err)
assert.Contains(t, err.Error(), "accounts service not available")
})
}
func TestManager_UpdateSettingsState(t *testing.T) {
t.Run("settings not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Settings: SettingsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
err := manager.updateSettingsState()
assert.Error(t, err)
assert.Contains(t, err.Error(), "settings portal not available")
})
}

View File

@@ -0,0 +1,14 @@
package freedesktop
const (
dbusAccountsDest = "org.freedesktop.Accounts"
dbusAccountsPath = "/org/freedesktop/Accounts"
dbusAccountsInterface = "org.freedesktop.Accounts"
dbusAccountsUserInterface = "org.freedesktop.Accounts.User"
dbusPortalDest = "org.freedesktop.portal.Desktop"
dbusPortalPath = "/org/freedesktop/portal/desktop"
dbusPortalSettingsInterface = "org.freedesktop.portal.Settings"
dbusPropsInterface = "org.freedesktop.DBus.Properties"
)

View File

@@ -0,0 +1,166 @@
package freedesktop
import (
"fmt"
"net"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
Value string `json:"value,omitempty"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method {
case "freedesktop.getState":
handleGetState(conn, req, manager)
case "freedesktop.accounts.setIconFile":
handleSetIconFile(conn, req, manager)
case "freedesktop.accounts.setRealName":
handleSetRealName(conn, req, manager)
case "freedesktop.accounts.setEmail":
handleSetEmail(conn, req, manager)
case "freedesktop.accounts.setLanguage":
handleSetLanguage(conn, req, manager)
case "freedesktop.accounts.setLocation":
handleSetLocation(conn, req, manager)
case "freedesktop.accounts.getUserIconFile":
handleGetUserIconFile(conn, req, manager)
case "freedesktop.settings.getColorScheme":
handleGetColorScheme(conn, req, manager)
case "freedesktop.settings.setIconTheme":
handleSetIconTheme(conn, req, manager)
default:
models.RespondError(conn, req.ID, fmt.Sprintf("unknown method: %s", req.Method))
}
}
func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState()
models.Respond(conn, req.ID, state)
}
func handleSetIconFile(conn net.Conn, req Request, manager *Manager) {
iconPath, ok := req.Params["path"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'path' parameter")
return
}
if err := manager.SetIconFile(iconPath); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "icon file set"})
}
func handleSetRealName(conn net.Conn, req Request, manager *Manager) {
name, ok := req.Params["name"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'name' parameter")
return
}
if err := manager.SetRealName(name); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "real name set"})
}
func handleSetEmail(conn net.Conn, req Request, manager *Manager) {
email, ok := req.Params["email"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'email' parameter")
return
}
if err := manager.SetEmail(email); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "email set"})
}
func handleSetLanguage(conn net.Conn, req Request, manager *Manager) {
language, ok := req.Params["language"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'language' parameter")
return
}
if err := manager.SetLanguage(language); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "language set"})
}
func handleSetLocation(conn net.Conn, req Request, manager *Manager) {
location, ok := req.Params["location"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'location' parameter")
return
}
if err := manager.SetLocation(location); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "location set"})
}
func handleGetUserIconFile(conn net.Conn, req Request, manager *Manager) {
username, ok := req.Params["username"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'username' parameter")
return
}
iconFile, err := manager.GetUserIconFile(username)
if err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Value: iconFile})
}
func handleGetColorScheme(conn net.Conn, req Request, manager *Manager) {
if err := manager.updateSettingsState(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
state := manager.GetState()
models.Respond(conn, req.ID, map[string]uint32{"colorScheme": state.Settings.ColorScheme})
}
func handleSetIconTheme(conn net.Conn, req Request, manager *Manager) {
iconTheme, ok := req.Params["iconTheme"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'iconTheme' parameter")
return
}
if err := manager.SetIconTheme(iconTheme); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "icon theme set"})
}

View File

@@ -0,0 +1,581 @@
package freedesktop
import (
"bytes"
"encoding/json"
"net"
"sync"
"testing"
mockdbus "github.com/AvengeMedia/danklinux/internal/mocks/github.com/godbus/dbus/v5"
"github.com/AvengeMedia/danklinux/internal/server/models"
"github.com/godbus/dbus/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockNetConn struct {
net.Conn
readBuf *bytes.Buffer
writeBuf *bytes.Buffer
closed bool
}
func newMockNetConn() *mockNetConn {
return &mockNetConn{
readBuf: &bytes.Buffer{},
writeBuf: &bytes.Buffer{},
}
}
func (m *mockNetConn) Read(b []byte) (n int, err error) {
return m.readBuf.Read(b)
}
func (m *mockNetConn) Write(b []byte) (n int, err error) {
return m.writeBuf.Write(b)
}
func (m *mockNetConn) Close() error {
m.closed = true
return nil
}
func mockGetAllAccountsProperties() *dbus.Call {
props := map[string]dbus.Variant{
"IconFile": dbus.MakeVariant("/path/to/icon.png"),
"RealName": dbus.MakeVariant("Test"),
"UserName": dbus.MakeVariant("test"),
"AccountType": dbus.MakeVariant(int32(0)),
"HomeDirectory": dbus.MakeVariant("/home/test"),
"Shell": dbus.MakeVariant("/bin/bash"),
"Email": dbus.MakeVariant(""),
"Language": dbus.MakeVariant(""),
"Location": dbus.MakeVariant(""),
"Locked": dbus.MakeVariant(false),
"PasswordMode": dbus.MakeVariant(int32(1)),
}
return &dbus.Call{Err: nil, Body: []interface{}{props}}
}
func TestRespondError_Freedesktop(t *testing.T) {
conn := newMockNetConn()
models.RespondError(conn, 123, "test error")
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Equal(t, "test error", resp.Error)
assert.Nil(t, resp.Result)
}
func TestRespond_Freedesktop(t *testing.T) {
conn := newMockNetConn()
result := SuccessResult{Success: true, Message: "test"}
models.Respond(conn, 123, result)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "test", resp.Result.Message)
}
func TestHandleGetState(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
UserName: "testuser",
RealName: "Test User",
UID: 1000,
},
Settings: SettingsState{
Available: true,
ColorScheme: 1,
},
},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "freedesktop.getState"}
handleGetState(conn, req, manager)
var resp models.Response[FreedeskState]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Accounts.Available)
assert.Equal(t, "testuser", resp.Result.Accounts.UserName)
assert.True(t, resp.Result.Settings.Available)
assert.Equal(t, uint32(1), resp.Result.Settings.ColorScheme)
}
func TestHandleSetIconFile(t *testing.T) {
t.Run("missing path parameter", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setIconFile",
Params: map[string]interface{}{},
}
handleSetIconFile(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'path' parameter")
})
t.Run("successful set icon file", func(t *testing.T) {
mockAccountsObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockAccountsObj.EXPECT().Call("org.freedesktop.Accounts.User.SetIconFile", dbus.Flags(0), "/path/to/icon.png").Return(mockCall)
mockAccountsObj.EXPECT().CallWithContext(mock.Anything, "org.freedesktop.DBus.Properties.GetAll", dbus.Flags(0), "org.freedesktop.Accounts.User").Return(mockGetAllAccountsProperties())
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
},
},
stateMutex: sync.RWMutex{},
accountsObj: mockAccountsObj,
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setIconFile",
Params: map[string]interface{}{
"path": "/path/to/icon.png",
},
}
handleSetIconFile(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "icon file set", resp.Result.Message)
})
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setIconFile",
Params: map[string]interface{}{
"path": "/path/to/icon.png",
},
}
handleSetIconFile(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "accounts service not available")
})
}
func TestHandleSetRealName(t *testing.T) {
t.Run("missing name parameter", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setRealName",
Params: map[string]interface{}{},
}
handleSetRealName(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'name' parameter")
})
t.Run("successful set real name", func(t *testing.T) {
mockAccountsObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockAccountsObj.EXPECT().Call("org.freedesktop.Accounts.User.SetRealName", dbus.Flags(0), "New Name").Return(mockCall)
mockAccountsObj.EXPECT().CallWithContext(mock.Anything, "org.freedesktop.DBus.Properties.GetAll", dbus.Flags(0), "org.freedesktop.Accounts.User").Return(mockGetAllAccountsProperties())
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
},
},
stateMutex: sync.RWMutex{},
accountsObj: mockAccountsObj,
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setRealName",
Params: map[string]interface{}{
"name": "New Name",
},
}
handleSetRealName(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "real name set", resp.Result.Message)
})
}
func TestHandleSetEmail(t *testing.T) {
t.Run("missing email parameter", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setEmail",
Params: map[string]interface{}{},
}
handleSetEmail(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'email' parameter")
})
t.Run("successful set email", func(t *testing.T) {
mockAccountsObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockAccountsObj.EXPECT().Call("org.freedesktop.Accounts.User.SetEmail", dbus.Flags(0), "test@example.com").Return(mockCall)
mockAccountsObj.EXPECT().CallWithContext(mock.Anything, "org.freedesktop.DBus.Properties.GetAll", dbus.Flags(0), "org.freedesktop.Accounts.User").Return(mockGetAllAccountsProperties())
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
},
},
stateMutex: sync.RWMutex{},
accountsObj: mockAccountsObj,
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setEmail",
Params: map[string]interface{}{
"email": "test@example.com",
},
}
handleSetEmail(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "email set", resp.Result.Message)
})
}
func TestHandleSetLanguage(t *testing.T) {
t.Run("missing language parameter", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setLanguage",
Params: map[string]interface{}{},
}
handleSetLanguage(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'language' parameter")
})
}
func TestHandleSetLocation(t *testing.T) {
t.Run("missing location parameter", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.setLocation",
Params: map[string]interface{}{},
}
handleSetLocation(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'location' parameter")
})
}
func TestHandleGetUserIconFile(t *testing.T) {
t.Run("missing username parameter", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.getUserIconFile",
Params: map[string]interface{}{},
}
handleGetUserIconFile(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'username' parameter")
})
t.Run("accounts not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.accounts.getUserIconFile",
Params: map[string]interface{}{
"username": "testuser",
},
}
handleGetUserIconFile(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "accounts service not available")
})
}
func TestHandleGetColorScheme(t *testing.T) {
t.Run("settings not available", func(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Settings: SettingsState{
Available: false,
},
},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "freedesktop.settings.getColorScheme"}
handleGetColorScheme(conn, req, manager)
var resp models.Response[map[string]uint32]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "settings portal not available")
})
t.Run("successful get color scheme", func(t *testing.T) {
mockSettingsObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{
Err: nil,
Body: []interface{}{dbus.MakeVariant(uint32(1))},
}
mockSettingsObj.EXPECT().Call("org.freedesktop.portal.Settings.ReadOne", dbus.Flags(0), "org.freedesktop.appearance", "color-scheme").Return(mockCall)
manager := &Manager{
state: &FreedeskState{
Settings: SettingsState{
Available: true,
},
},
stateMutex: sync.RWMutex{},
settingsObj: mockSettingsObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "freedesktop.settings.getColorScheme"}
handleGetColorScheme(conn, req, manager)
var resp models.Response[map[string]uint32]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.Equal(t, uint32(1), (*resp.Result)["colorScheme"])
})
}
func TestHandleRequest(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
UserName: "testuser",
},
},
stateMutex: sync.RWMutex{},
}
t.Run("unknown method", func(t *testing.T) {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.unknown",
}
HandleRequest(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "unknown method")
})
t.Run("valid method - getState", func(t *testing.T) {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "freedesktop.getState",
}
HandleRequest(conn, req, manager)
var resp models.Response[FreedeskState]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
})
t.Run("all method routes", func(t *testing.T) {
tests := []string{
"freedesktop.accounts.setIconFile",
"freedesktop.accounts.setRealName",
"freedesktop.accounts.setEmail",
"freedesktop.accounts.setLanguage",
"freedesktop.accounts.setLocation",
"freedesktop.accounts.getUserIconFile",
"freedesktop.settings.getColorScheme",
}
for _, method := range tests {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: method,
Params: map[string]interface{}{},
}
HandleRequest(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
// Will have errors due to missing params or service unavailable
// but the method routing should work
}
})
}

View File

@@ -0,0 +1,251 @@
package freedesktop
import (
"context"
"fmt"
"os"
"sync"
"github.com/godbus/dbus/v5"
)
func NewManager() (*Manager, error) {
systemConn, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("failed to connect to system bus: %w", err)
}
sessionConn, err := dbus.ConnectSessionBus()
if err != nil {
sessionConn = nil
}
m := &Manager{
state: &FreedeskState{
Accounts: AccountsState{},
Settings: SettingsState{},
},
stateMutex: sync.RWMutex{},
systemConn: systemConn,
sessionConn: sessionConn,
currentUID: uint64(os.Getuid()),
subscribers: make(map[string]chan FreedeskState),
subMutex: sync.RWMutex{},
}
m.initializeAccounts()
m.initializeSettings()
return m, nil
}
func (m *Manager) initializeAccounts() error {
accountsManager := m.systemConn.Object(dbusAccountsDest, dbus.ObjectPath(dbusAccountsPath))
var userPath dbus.ObjectPath
err := accountsManager.Call(dbusAccountsInterface+".FindUserById", 0, int64(m.currentUID)).Store(&userPath)
if err != nil {
m.stateMutex.Lock()
m.state.Accounts.Available = false
m.stateMutex.Unlock()
return err
}
m.accountsObj = m.systemConn.Object(dbusAccountsDest, userPath)
m.stateMutex.Lock()
m.state.Accounts.Available = true
m.state.Accounts.UserPath = string(userPath)
m.state.Accounts.UID = m.currentUID
m.stateMutex.Unlock()
if err := m.updateAccountsState(); err != nil {
return fmt.Errorf("failed to update accounts state: %w", err)
}
return nil
}
func (m *Manager) initializeSettings() error {
if m.sessionConn == nil {
m.stateMutex.Lock()
m.state.Settings.Available = false
m.stateMutex.Unlock()
return fmt.Errorf("no session bus connection")
}
m.settingsObj = m.sessionConn.Object(dbusPortalDest, dbus.ObjectPath(dbusPortalPath))
var variant dbus.Variant
err := m.settingsObj.Call(dbusPortalSettingsInterface+".ReadOne", 0, "org.freedesktop.appearance", "color-scheme").Store(&variant)
if err != nil {
m.stateMutex.Lock()
m.state.Settings.Available = false
m.stateMutex.Unlock()
return err
}
m.stateMutex.Lock()
m.state.Settings.Available = true
m.stateMutex.Unlock()
if err := m.updateSettingsState(); err != nil {
return fmt.Errorf("failed to update settings state: %w", err)
}
return nil
}
func (m *Manager) updateAccountsState() error {
if !m.state.Accounts.Available || m.accountsObj == nil {
return fmt.Errorf("accounts service not available")
}
ctx := context.Background()
props, err := m.getAccountProperties(ctx)
if err != nil {
return err
}
m.stateMutex.Lock()
defer m.stateMutex.Unlock()
if v, ok := props["IconFile"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.IconFile = val
}
}
if v, ok := props["RealName"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.RealName = val
}
}
if v, ok := props["UserName"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.UserName = val
}
}
if v, ok := props["AccountType"]; ok {
if val, ok := v.Value().(int32); ok {
m.state.Accounts.AccountType = val
}
}
if v, ok := props["HomeDirectory"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.HomeDirectory = val
}
}
if v, ok := props["Shell"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.Shell = val
}
}
if v, ok := props["Email"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.Email = val
}
}
if v, ok := props["Language"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.Language = val
}
}
if v, ok := props["Location"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Accounts.Location = val
}
}
if v, ok := props["Locked"]; ok {
if val, ok := v.Value().(bool); ok {
m.state.Accounts.Locked = val
}
}
if v, ok := props["PasswordMode"]; ok {
if val, ok := v.Value().(int32); ok {
m.state.Accounts.PasswordMode = val
}
}
return nil
}
func (m *Manager) updateSettingsState() error {
if !m.state.Settings.Available || m.settingsObj == nil {
return fmt.Errorf("settings portal not available")
}
var variant dbus.Variant
err := m.settingsObj.Call(dbusPortalSettingsInterface+".ReadOne", 0, "org.freedesktop.appearance", "color-scheme").Store(&variant)
if err != nil {
return err
}
if colorScheme, ok := variant.Value().(uint32); ok {
m.stateMutex.Lock()
m.state.Settings.ColorScheme = colorScheme
m.stateMutex.Unlock()
}
return nil
}
func (m *Manager) getAccountProperties(ctx context.Context) (map[string]dbus.Variant, error) {
var props map[string]dbus.Variant
err := m.accountsObj.CallWithContext(ctx, dbusPropsInterface+".GetAll", 0, dbusAccountsUserInterface).Store(&props)
if err != nil {
return nil, err
}
return props, nil
}
func (m *Manager) GetState() FreedeskState {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
return *m.state
}
func (m *Manager) Subscribe(id string) chan FreedeskState {
ch := make(chan FreedeskState, 64)
m.subMutex.Lock()
m.subscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) NotifySubscribers() {
m.subMutex.RLock()
defer m.subMutex.RUnlock()
state := m.GetState()
for _, ch := range m.subscribers {
select {
case ch <- state:
default:
}
}
}
func (m *Manager) Close() {
m.subMutex.Lock()
for id, ch := range m.subscribers {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
if m.systemConn != nil {
m.systemConn.Close()
}
if m.sessionConn != nil {
m.sessionConn.Close()
}
}

View File

@@ -0,0 +1,143 @@
package freedesktop
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestManager_GetState(t *testing.T) {
state := &FreedeskState{
Accounts: AccountsState{
Available: true,
UserName: "testuser",
RealName: "Test User",
UID: 1000,
},
Settings: SettingsState{
Available: true,
ColorScheme: 1,
},
}
manager := &Manager{
state: state,
stateMutex: sync.RWMutex{},
}
result := manager.GetState()
assert.True(t, result.Accounts.Available)
assert.Equal(t, "testuser", result.Accounts.UserName)
assert.Equal(t, "Test User", result.Accounts.RealName)
assert.Equal(t, uint64(1000), result.Accounts.UID)
assert.True(t, result.Settings.Available)
assert.Equal(t, uint32(1), result.Settings.ColorScheme)
}
func TestManager_GetState_ThreadSafe(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
UserName: "testuser",
},
Settings: SettingsState{
Available: true,
ColorScheme: 1,
},
},
stateMutex: sync.RWMutex{},
}
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
state := manager.GetState()
assert.True(t, state.Accounts.Available)
assert.Equal(t, "testuser", state.Accounts.UserName)
done <- true
}()
}
for i := 0; i < 10; i++ {
<-done
}
}
func TestManager_Close(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
systemConn: nil,
sessionConn: nil,
}
assert.NotPanics(t, func() {
manager.Close()
})
}
func TestNewManager(t *testing.T) {
t.Run("attempts to create manager", func(t *testing.T) {
manager, err := NewManager()
if err != nil {
assert.Nil(t, manager)
} else {
assert.NotNil(t, manager)
assert.NotNil(t, manager.state)
assert.NotNil(t, manager.systemConn)
manager.Close()
}
})
}
func TestManager_GetState_EmptyState(t *testing.T) {
manager := &Manager{
state: &FreedeskState{},
stateMutex: sync.RWMutex{},
}
result := manager.GetState()
assert.False(t, result.Accounts.Available)
assert.Empty(t, result.Accounts.UserName)
assert.False(t, result.Settings.Available)
assert.Equal(t, uint32(0), result.Settings.ColorScheme)
}
func TestManager_AccountsState_Modification(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Accounts: AccountsState{
Available: true,
UserName: "testuser",
},
},
stateMutex: sync.RWMutex{},
}
state := manager.GetState()
state.Accounts.UserName = "modifieduser"
original := manager.GetState()
assert.Equal(t, "testuser", original.Accounts.UserName)
}
func TestManager_SettingsState_Modification(t *testing.T) {
manager := &Manager{
state: &FreedeskState{
Settings: SettingsState{
Available: true,
ColorScheme: 0,
},
},
stateMutex: sync.RWMutex{},
}
state := manager.GetState()
state.Settings.ColorScheme = 1
original := manager.GetState()
assert.Equal(t, uint32(0), original.Settings.ColorScheme)
}

View File

@@ -0,0 +1,46 @@
package freedesktop
import (
"sync"
"github.com/godbus/dbus/v5"
)
type AccountsState struct {
Available bool `json:"available"`
UserPath string `json:"userPath"`
IconFile string `json:"iconFile"`
RealName string `json:"realName"`
UserName string `json:"userName"`
AccountType int32 `json:"accountType"`
HomeDirectory string `json:"homeDirectory"`
Shell string `json:"shell"`
Email string `json:"email"`
Language string `json:"language"`
Location string `json:"location"`
Locked bool `json:"locked"`
PasswordMode int32 `json:"passwordMode"`
UID uint64 `json:"uid"`
}
type SettingsState struct {
Available bool `json:"available"`
ColorScheme uint32 `json:"colorScheme"`
}
type FreedeskState struct {
Accounts AccountsState `json:"accounts"`
Settings SettingsState `json:"settings"`
}
type Manager struct {
state *FreedeskState
stateMutex sync.RWMutex
systemConn *dbus.Conn
sessionConn *dbus.Conn
accountsObj dbus.BusObject
settingsObj dbus.BusObject
currentUID uint64
subscribers map[string]chan FreedeskState
subMutex sync.RWMutex
}

View File

@@ -0,0 +1,70 @@
package freedesktop
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAccountsState_Struct(t *testing.T) {
state := AccountsState{
Available: true,
UserPath: "/org/freedesktop/Accounts/User1000",
RealName: "Test User",
UserName: "testuser",
Locked: false,
UID: 1000,
}
assert.True(t, state.Available)
assert.Equal(t, "/org/freedesktop/Accounts/User1000", state.UserPath)
assert.Equal(t, "Test User", state.RealName)
assert.Equal(t, "testuser", state.UserName)
assert.Equal(t, uint64(1000), state.UID)
assert.False(t, state.Locked)
}
func TestSettingsState_Struct(t *testing.T) {
state := SettingsState{
Available: true,
ColorScheme: 1, // Dark mode
}
assert.True(t, state.Available)
assert.Equal(t, uint32(1), state.ColorScheme)
}
func TestFreedeskState_Struct(t *testing.T) {
state := FreedeskState{
Accounts: AccountsState{
Available: true,
UserName: "testuser",
UID: 1000,
},
Settings: SettingsState{
Available: true,
ColorScheme: 0, // Light mode
},
}
assert.True(t, state.Accounts.Available)
assert.Equal(t, "testuser", state.Accounts.UserName)
assert.True(t, state.Settings.Available)
assert.Equal(t, uint32(0), state.Settings.ColorScheme)
}
func TestAccountsState_DefaultValues(t *testing.T) {
state := AccountsState{}
assert.False(t, state.Available)
assert.Empty(t, state.UserPath)
assert.Empty(t, state.UserName)
assert.Equal(t, uint64(0), state.UID)
}
func TestSettingsState_DefaultValues(t *testing.T) {
state := SettingsState{}
assert.False(t, state.Available)
assert.Equal(t, uint32(0), state.ColorScheme)
}

View File

@@ -0,0 +1,88 @@
package loginctl
import (
"fmt"
)
func (m *Manager) Lock() error {
if m.sessionObj == nil {
return fmt.Errorf("session object not available")
}
err := m.sessionObj.Call(dbusSessionInterface+".Lock", 0).Err
if err != nil {
if refreshErr := m.refreshSessionBinding(); refreshErr == nil {
err = m.sessionObj.Call(dbusSessionInterface+".Lock", 0).Err
}
if err != nil {
return fmt.Errorf("failed to lock session: %w", err)
}
}
return nil
}
func (m *Manager) Unlock() error {
err := m.sessionObj.Call(dbusSessionInterface+".Unlock", 0).Err
if err != nil {
if refreshErr := m.refreshSessionBinding(); refreshErr == nil {
err = m.sessionObj.Call(dbusSessionInterface+".Unlock", 0).Err
}
if err != nil {
return fmt.Errorf("failed to unlock session: %w", err)
}
}
return nil
}
func (m *Manager) Activate() error {
err := m.sessionObj.Call(dbusSessionInterface+".Activate", 0).Err
if err != nil {
if refreshErr := m.refreshSessionBinding(); refreshErr == nil {
err = m.sessionObj.Call(dbusSessionInterface+".Activate", 0).Err
}
if err != nil {
return fmt.Errorf("failed to activate session: %w", err)
}
}
return nil
}
func (m *Manager) SetIdleHint(idle bool) error {
err := m.sessionObj.Call(dbusSessionInterface+".SetIdleHint", 0, idle).Err
if err != nil {
if refreshErr := m.refreshSessionBinding(); refreshErr == nil {
err = m.sessionObj.Call(dbusSessionInterface+".SetIdleHint", 0, idle).Err
}
if err != nil {
return fmt.Errorf("failed to set idle hint: %w", err)
}
}
return nil
}
func (m *Manager) Terminate() error {
err := m.sessionObj.Call(dbusSessionInterface+".Terminate", 0).Err
if err != nil {
if refreshErr := m.refreshSessionBinding(); refreshErr == nil {
err = m.sessionObj.Call(dbusSessionInterface+".Terminate", 0).Err
}
if err != nil {
return fmt.Errorf("failed to terminate session: %w", err)
}
}
return nil
}
func (m *Manager) SetLockBeforeSuspend(enabled bool) {
m.lockBeforeSuspend.Store(enabled)
}
func (m *Manager) SetSleepInhibitorEnabled(enabled bool) {
m.sleepInhibitorEnabled.Store(enabled)
if enabled {
// Re-acquire inhibitor if enabled
m.acquireSleepInhibitor()
} else {
// Release inhibitor if disabled
m.releaseSleepInhibitor()
}
}

View File

@@ -0,0 +1,9 @@
package loginctl
const (
dbusDest = "org.freedesktop.login1"
dbusPath = "/org/freedesktop/login1"
dbusManagerInterface = "org.freedesktop.login1.Manager"
dbusSessionInterface = "org.freedesktop.login1.Session"
dbusPropsInterface = "org.freedesktop.DBus.Properties"
)

View File

@@ -0,0 +1,167 @@
package loginctl
import (
"encoding/json"
"fmt"
"net"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method {
case "loginctl.getState":
handleGetState(conn, req, manager)
case "loginctl.lock":
handleLock(conn, req, manager)
case "loginctl.unlock":
handleUnlock(conn, req, manager)
case "loginctl.activate":
handleActivate(conn, req, manager)
case "loginctl.setIdleHint":
handleSetIdleHint(conn, req, manager)
case "loginctl.setLockBeforeSuspend":
handleSetLockBeforeSuspend(conn, req, manager)
case "loginctl.setSleepInhibitorEnabled":
handleSetSleepInhibitorEnabled(conn, req, manager)
case "loginctl.lockerReady":
handleLockerReady(conn, req, manager)
case "loginctl.terminate":
handleTerminate(conn, req, manager)
case "loginctl.subscribe":
handleSubscribe(conn, req, manager)
default:
models.RespondError(conn, req.ID, fmt.Sprintf("unknown method: %s", req.Method))
}
}
func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState()
models.Respond(conn, req.ID, state)
}
func handleLock(conn net.Conn, req Request, manager *Manager) {
if err := manager.Lock(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "locked"})
}
func handleUnlock(conn net.Conn, req Request, manager *Manager) {
if err := manager.Unlock(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "unlocked"})
}
func handleActivate(conn net.Conn, req Request, manager *Manager) {
if err := manager.Activate(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "activated"})
}
func handleSetIdleHint(conn net.Conn, req Request, manager *Manager) {
idle, ok := req.Params["idle"].(bool)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'idle' parameter")
return
}
if err := manager.SetIdleHint(idle); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "idle hint set"})
}
func handleSetLockBeforeSuspend(conn net.Conn, req Request, manager *Manager) {
enabled, ok := req.Params["enabled"].(bool)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'enabled' parameter")
return
}
manager.SetLockBeforeSuspend(enabled)
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "lock before suspend set"})
}
func handleSetSleepInhibitorEnabled(conn net.Conn, req Request, manager *Manager) {
enabled, ok := req.Params["enabled"].(bool)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'enabled' parameter")
return
}
manager.SetSleepInhibitorEnabled(enabled)
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "sleep inhibitor setting updated"})
}
func handleLockerReady(conn net.Conn, req Request, manager *Manager) {
manager.lockTimerMu.Lock()
if manager.lockTimer != nil {
manager.lockTimer.Stop()
manager.lockTimer = nil
}
manager.lockTimerMu.Unlock()
id := manager.sleepCycleID.Load()
manager.releaseForCycle(id)
if manager.inSleepCycle.Load() {
manager.signalLockerReady()
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "ok"})
}
func handleTerminate(conn net.Conn, req Request, manager *Manager) {
if err := manager.Terminate(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "terminated"})
}
func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID)
initialState := manager.GetState()
event := SessionEvent{
Type: EventStateChanged,
Data: initialState,
}
if err := json.NewEncoder(conn).Encode(models.Response[SessionEvent]{
ID: req.ID,
Result: &event,
}); err != nil {
return
}
for state := range stateChan {
event := SessionEvent{
Type: EventStateChanged,
Data: state,
}
if err := json.NewEncoder(conn).Encode(models.Response[SessionEvent]{
Result: &event,
}); err != nil {
return
}
}
}

View File

@@ -0,0 +1,502 @@
package loginctl
import (
"bytes"
"encoding/json"
"net"
"sync"
"testing"
"time"
mockdbus "github.com/AvengeMedia/danklinux/internal/mocks/github.com/godbus/dbus/v5"
"github.com/AvengeMedia/danklinux/internal/server/models"
"github.com/godbus/dbus/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
type mockNetConn struct {
net.Conn
readBuf *bytes.Buffer
writeBuf *bytes.Buffer
closed bool
}
func newMockNetConn() *mockNetConn {
return &mockNetConn{
readBuf: &bytes.Buffer{},
writeBuf: &bytes.Buffer{},
}
}
func (m *mockNetConn) Read(b []byte) (n int, err error) {
return m.readBuf.Read(b)
}
func (m *mockNetConn) Write(b []byte) (n int, err error) {
return m.writeBuf.Write(b)
}
func (m *mockNetConn) Close() error {
m.closed = true
return nil
}
func TestRespondError_Loginctl(t *testing.T) {
conn := newMockNetConn()
models.RespondError(conn, 123, "test error")
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Equal(t, "test error", resp.Error)
assert.Nil(t, resp.Result)
}
func TestRespond_Loginctl(t *testing.T) {
conn := newMockNetConn()
result := SuccessResult{Success: true, Message: "test"}
models.Respond(conn, 123, result)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "test", resp.Result.Message)
}
func TestHandleGetState(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
Active: true,
SessionType: "wayland",
SessionClass: "user",
UserName: "testuser",
},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.getState"}
handleGetState(conn, req, manager)
var resp models.Response[SessionState]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.Equal(t, "1", resp.Result.SessionID)
assert.False(t, resp.Result.Locked)
assert.True(t, resp.Result.Active)
}
func TestHandleLock(t *testing.T) {
t.Run("successful lock", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Lock", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.lock"}
handleLock(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "locked", resp.Result.Message)
})
t.Run("lock fails", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: assert.AnError}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Lock", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.lock"}
handleLock(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "failed to lock session")
})
}
func TestHandleUnlock(t *testing.T) {
t.Run("successful unlock", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Unlock", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.unlock"}
handleUnlock(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "unlocked", resp.Result.Message)
})
t.Run("unlock fails", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: assert.AnError}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Unlock", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.unlock"}
handleUnlock(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "failed to unlock session")
})
}
func TestHandleActivate(t *testing.T) {
t.Run("successful activate", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Activate", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.activate"}
handleActivate(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "activated", resp.Result.Message)
})
t.Run("activate fails", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: assert.AnError}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Activate", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.activate"}
handleActivate(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "failed to activate session")
})
}
func TestHandleSetIdleHint(t *testing.T) {
t.Run("missing idle parameter", func(t *testing.T) {
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "loginctl.setIdleHint",
Params: map[string]interface{}{},
}
handleSetIdleHint(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'idle' parameter")
})
t.Run("successful set idle hint true", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.SetIdleHint", dbus.Flags(0), true).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "loginctl.setIdleHint",
Params: map[string]interface{}{
"idle": true,
},
}
handleSetIdleHint(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "idle hint set", resp.Result.Message)
})
t.Run("set idle hint fails", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: assert.AnError}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.SetIdleHint", dbus.Flags(0), false).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "loginctl.setIdleHint",
Params: map[string]interface{}{
"idle": false,
},
}
handleSetIdleHint(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "failed to set idle hint")
})
}
func TestHandleTerminate(t *testing.T) {
t.Run("successful terminate", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Terminate", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.terminate"}
handleTerminate(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "terminated", resp.Result.Message)
})
t.Run("terminate fails", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: assert.AnError}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Terminate", dbus.Flags(0)).Return(mockCall)
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
sessionObj: mockSessionObj,
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.terminate"}
handleTerminate(conn, req, manager)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "failed to terminate session")
})
}
func TestHandleRequest(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
},
stateMutex: sync.RWMutex{},
}
t.Run("unknown method", func(t *testing.T) {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "loginctl.unknown",
}
HandleRequest(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "unknown method")
})
t.Run("valid method - getState", func(t *testing.T) {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "loginctl.getState",
}
HandleRequest(conn, req, manager)
var resp models.Response[SessionState]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
})
t.Run("lock method", func(t *testing.T) {
mockSessionObj := mockdbus.NewMockBusObject(t)
mockCall := &dbus.Call{Err: nil}
mockSessionObj.EXPECT().Call("org.freedesktop.login1.Session.Lock", mock.Anything).Return(mockCall)
manager.sessionObj = mockSessionObj
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "loginctl.lock",
}
HandleRequest(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
})
}
func TestHandleSubscribe(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "loginctl.subscribe"}
done := make(chan bool)
go func() {
handleSubscribe(conn, req, manager)
done <- true
}()
time.Sleep(50 * time.Millisecond)
conn.Close()
if conn.writeBuf.Len() > 0 {
var resp models.Response[SessionEvent]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
if err == nil {
assert.Equal(t, 123, resp.ID)
require.NotNil(t, resp.Result)
assert.Equal(t, EventStateChanged, resp.Result.Type)
assert.Equal(t, "1", resp.Result.Data.SessionID)
}
}
select {
case <-done:
case <-time.After(100 * time.Millisecond):
}
}

View File

@@ -0,0 +1,597 @@
package loginctl
import (
"context"
"fmt"
"os"
"sync"
"time"
"github.com/godbus/dbus/v5"
)
func NewManager() (*Manager, error) {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("failed to connect to system bus: %w", err)
}
sessionID := os.Getenv("XDG_SESSION_ID")
if sessionID == "" {
sessionID = "self"
}
m := &Manager{
state: &SessionState{
SessionID: sessionID,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
conn: conn,
dirty: make(chan struct{}, 1),
signals: make(chan *dbus.Signal, 256),
}
m.sleepInhibitorEnabled.Store(true)
if err := m.initialize(); err != nil {
conn.Close()
return nil, err
}
if err := m.acquireSleepInhibitor(); err != nil {
fmt.Fprintf(os.Stderr, "sleep inhibitor unavailable: %v\n", err)
}
m.notifierWg.Add(1)
go m.notifier()
if err := m.startSignalPump(); err != nil {
m.Close()
return nil, err
}
return m, nil
}
func (m *Manager) initialize() error {
m.managerObj = m.conn.Object(dbusDest, dbus.ObjectPath(dbusPath))
m.initializeFallbackDelay()
sessionPath, err := m.getSession(m.state.SessionID)
if err != nil {
return fmt.Errorf("failed to get session path: %w", err)
}
m.stateMutex.Lock()
m.state.SessionPath = string(sessionPath)
m.sessionPath = sessionPath
m.stateMutex.Unlock()
m.sessionObj = m.conn.Object(dbusDest, sessionPath)
if err := m.updateSessionState(); err != nil {
return err
}
return nil
}
func (m *Manager) getSession(id string) (dbus.ObjectPath, error) {
var out dbus.ObjectPath
err := m.managerObj.Call(dbusManagerInterface+".GetSession", 0, id).Store(&out)
if err != nil {
return "", err
}
return out, nil
}
func (m *Manager) refreshSessionBinding() error {
if m.managerObj == nil || m.conn == nil {
return fmt.Errorf("manager not fully initialized")
}
sessionPath, err := m.getSession(m.state.SessionID)
if err != nil {
return fmt.Errorf("failed to get session path: %w", err)
}
m.stateMutex.RLock()
currentPath := m.sessionPath
m.stateMutex.RUnlock()
if sessionPath == currentPath {
return nil
}
m.stopSignalPump()
m.stateMutex.Lock()
m.state.SessionPath = string(sessionPath)
m.sessionPath = sessionPath
m.stateMutex.Unlock()
m.sessionObj = m.conn.Object(dbusDest, sessionPath)
if err := m.updateSessionState(); err != nil {
return err
}
m.signals = make(chan *dbus.Signal, 256)
return m.startSignalPump()
}
func (m *Manager) updateSessionState() error {
ctx := context.Background()
props, err := m.getSessionProperties(ctx)
if err != nil {
return err
}
m.stateMutex.Lock()
defer m.stateMutex.Unlock()
if v, ok := props["Active"]; ok {
if val, ok := v.Value().(bool); ok {
m.state.Active = val
}
}
if v, ok := props["IdleHint"]; ok {
if val, ok := v.Value().(bool); ok {
m.state.IdleHint = val
}
}
if v, ok := props["IdleSinceHint"]; ok {
if val, ok := v.Value().(uint64); ok {
m.state.IdleSinceHint = val
}
}
if v, ok := props["LockedHint"]; ok {
if val, ok := v.Value().(bool); ok {
m.state.LockedHint = val
m.state.Locked = val
}
}
if v, ok := props["Type"]; ok {
if val, ok := v.Value().(string); ok {
m.state.SessionType = val
}
}
if v, ok := props["Class"]; ok {
if val, ok := v.Value().(string); ok {
m.state.SessionClass = val
}
}
if v, ok := props["User"]; ok {
if userArr, ok := v.Value().([]interface{}); ok && len(userArr) >= 1 {
if uid, ok := userArr[0].(uint32); ok {
m.state.User = uid
}
}
}
if v, ok := props["Name"]; ok {
if val, ok := v.Value().(string); ok {
m.state.UserName = val
}
}
if v, ok := props["RemoteHost"]; ok {
if val, ok := v.Value().(string); ok {
m.state.RemoteHost = val
}
}
if v, ok := props["Service"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Service = val
}
}
if v, ok := props["TTY"]; ok {
if val, ok := v.Value().(string); ok {
m.state.TTY = val
}
}
if v, ok := props["Display"]; ok {
if val, ok := v.Value().(string); ok {
m.state.Display = val
}
}
if v, ok := props["Remote"]; ok {
if val, ok := v.Value().(bool); ok {
m.state.Remote = val
}
}
if v, ok := props["Seat"]; ok {
if seatArr, ok := v.Value().([]interface{}); ok && len(seatArr) >= 1 {
if seatID, ok := seatArr[0].(string); ok {
m.state.Seat = seatID
}
}
}
if v, ok := props["VTNr"]; ok {
if val, ok := v.Value().(uint32); ok {
m.state.VTNr = val
}
}
return nil
}
func (m *Manager) getSessionProperties(ctx context.Context) (map[string]dbus.Variant, error) {
var props map[string]dbus.Variant
err := m.sessionObj.CallWithContext(ctx, dbusPropsInterface+".GetAll", 0, dbusSessionInterface).Store(&props)
if err != nil {
return nil, err
}
return props, nil
}
func (m *Manager) acquireSleepInhibitor() error {
if !m.sleepInhibitorEnabled.Load() {
return nil
}
m.inhibitMu.Lock()
defer m.inhibitMu.Unlock()
if m.inhibitFile != nil {
return nil
}
if m.managerObj == nil {
return fmt.Errorf("manager object not available")
}
file, err := m.inhibit("sleep", "DankMaterialShell", "Lock before suspend", "delay")
if err != nil {
return err
}
m.inhibitFile = file
return nil
}
func (m *Manager) inhibit(what, who, why, mode string) (*os.File, error) {
var fd dbus.UnixFD
err := m.managerObj.Call(dbusManagerInterface+".Inhibit", 0, what, who, why, mode).Store(&fd)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), "inhibit"), nil
}
func (m *Manager) releaseSleepInhibitor() {
m.inhibitMu.Lock()
f := m.inhibitFile
m.inhibitFile = nil
m.inhibitMu.Unlock()
if f != nil {
f.Close()
}
}
func (m *Manager) releaseForCycle(id uint64) {
if !m.inSleepCycle.Load() || m.sleepCycleID.Load() != id {
return
}
m.releaseSleepInhibitor()
}
func (m *Manager) initializeFallbackDelay() {
var maxDelayUSec uint64
err := m.managerObj.Call(
dbusPropsInterface+".Get",
0,
dbusManagerInterface,
"InhibitDelayMaxUSec",
).Store(&maxDelayUSec)
if err != nil {
m.fallbackDelay = 2 * time.Second
return
}
maxDelay := time.Duration(maxDelayUSec) * time.Microsecond
computed := (maxDelay * 8) / 10
if computed < 2*time.Second {
m.fallbackDelay = 2 * time.Second
} else if computed > 4*time.Second {
m.fallbackDelay = 4 * time.Second
} else {
m.fallbackDelay = computed
}
}
func (m *Manager) newLockerReadyCh() chan struct{} {
m.lockerReadyChMu.Lock()
defer m.lockerReadyChMu.Unlock()
m.lockerReadyCh = make(chan struct{})
return m.lockerReadyCh
}
func (m *Manager) signalLockerReady() {
m.lockerReadyChMu.Lock()
ch := m.lockerReadyCh
if ch != nil {
close(ch)
m.lockerReadyCh = nil
}
m.lockerReadyChMu.Unlock()
}
func (m *Manager) snapshotState() SessionState {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
return *m.state
}
func stateChangedMeaningfully(old, new *SessionState) bool {
if old.Locked != new.Locked {
return true
}
if old.LockedHint != new.LockedHint {
return true
}
if old.Active != new.Active {
return true
}
if old.IdleHint != new.IdleHint {
return true
}
if old.PreparingForSleep != new.PreparingForSleep {
return true
}
return false
}
func (m *Manager) GetState() SessionState {
return m.snapshotState()
}
func (m *Manager) Subscribe(id string) chan SessionState {
ch := make(chan SessionState, 64)
m.subMutex.Lock()
m.subscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) notifier() {
defer m.notifierWg.Done()
const minGap = 100 * time.Millisecond
timer := time.NewTimer(minGap)
timer.Stop()
var pending bool
for {
select {
case <-m.stopChan:
timer.Stop()
return
case <-m.dirty:
if pending {
continue
}
pending = true
timer.Reset(minGap)
case <-timer.C:
if !pending {
continue
}
m.subMutex.RLock()
if len(m.subscribers) == 0 {
m.subMutex.RUnlock()
pending = false
continue
}
currentState := m.snapshotState()
if m.lastNotifiedState != nil && !stateChangedMeaningfully(m.lastNotifiedState, &currentState) {
m.subMutex.RUnlock()
pending = false
continue
}
for _, ch := range m.subscribers {
select {
case ch <- currentState:
default:
}
}
m.subMutex.RUnlock()
stateCopy := currentState
m.lastNotifiedState = &stateCopy
pending = false
}
}
}
func (m *Manager) notifySubscribers() {
select {
case m.dirty <- struct{}{}:
default:
}
}
func (m *Manager) startSignalPump() error {
m.conn.Signal(m.signals)
if err := m.conn.AddMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
); err != nil {
m.conn.RemoveSignal(m.signals)
return err
}
if err := m.conn.AddMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Lock"),
); err != nil {
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
m.conn.RemoveSignal(m.signals)
return err
}
if err := m.conn.AddMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Unlock"),
); err != nil {
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Lock"),
)
m.conn.RemoveSignal(m.signals)
return err
}
if err := m.conn.AddMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusPath)),
dbus.WithMatchInterface(dbusManagerInterface),
dbus.WithMatchMember("PrepareForSleep"),
); err != nil {
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Lock"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Unlock"),
)
m.conn.RemoveSignal(m.signals)
return err
}
if err := m.conn.AddMatchSignal(
dbus.WithMatchObjectPath("/org/freedesktop/DBus"),
dbus.WithMatchInterface("org.freedesktop.DBus"),
dbus.WithMatchMember("NameOwnerChanged"),
); err != nil {
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Lock"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Unlock"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusPath)),
dbus.WithMatchInterface(dbusManagerInterface),
dbus.WithMatchMember("PrepareForSleep"),
)
m.conn.RemoveSignal(m.signals)
return err
}
m.sigWG.Add(1)
go func() {
defer m.sigWG.Done()
for {
select {
case <-m.stopChan:
return
case sig, ok := <-m.signals:
if !ok {
return
}
if sig == nil {
continue
}
m.handleDBusSignal(sig)
}
}
}()
return nil
}
func (m *Manager) stopSignalPump() {
if m.conn == nil {
return
}
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Lock"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(m.sessionPath),
dbus.WithMatchInterface(dbusSessionInterface),
dbus.WithMatchMember("Unlock"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusPath)),
dbus.WithMatchInterface(dbusManagerInterface),
dbus.WithMatchMember("PrepareForSleep"),
)
m.conn.RemoveMatchSignal(
dbus.WithMatchObjectPath("/org/freedesktop/DBus"),
dbus.WithMatchInterface("org.freedesktop.DBus"),
dbus.WithMatchMember("NameOwnerChanged"),
)
m.conn.RemoveSignal(m.signals)
close(m.signals)
m.sigWG.Wait()
}
func (m *Manager) Close() {
close(m.stopChan)
m.notifierWg.Wait()
m.stopSignalPump()
m.releaseSleepInhibitor()
m.subMutex.Lock()
for _, ch := range m.subscribers {
close(ch)
}
m.subscribers = make(map[string]chan SessionState)
m.subMutex.Unlock()
if m.conn != nil {
m.conn.Close()
}
}

View File

@@ -0,0 +1,313 @@
package loginctl
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestManager_GetState(t *testing.T) {
state := &SessionState{
SessionID: "1",
Locked: false,
Active: true,
IdleHint: false,
SessionType: "wayland",
SessionClass: "user",
UserName: "testuser",
}
manager := &Manager{
state: state,
stateMutex: sync.RWMutex{},
}
result := manager.GetState()
assert.Equal(t, "1", result.SessionID)
assert.False(t, result.Locked)
assert.True(t, result.Active)
assert.Equal(t, "wayland", result.SessionType)
assert.Equal(t, "testuser", result.UserName)
}
func TestManager_Subscribe(t *testing.T) {
manager := &Manager{
state: &SessionState{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
}
ch := manager.Subscribe("test-client")
assert.NotNil(t, ch)
assert.Equal(t, 64, cap(ch))
manager.subMutex.RLock()
_, exists := manager.subscribers["test-client"]
manager.subMutex.RUnlock()
assert.True(t, exists)
}
func TestManager_Unsubscribe(t *testing.T) {
manager := &Manager{
state: &SessionState{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
}
ch := manager.Subscribe("test-client")
manager.Unsubscribe("test-client")
_, ok := <-ch
assert.False(t, ok)
manager.subMutex.RLock()
_, exists := manager.subscribers["test-client"]
manager.subMutex.RUnlock()
assert.False(t, exists)
}
func TestManager_Unsubscribe_NonExistent(t *testing.T) {
manager := &Manager{
state: &SessionState{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
}
// Unsubscribe a non-existent client should not panic
assert.NotPanics(t, func() {
manager.Unsubscribe("non-existent")
})
}
func TestManager_NotifySubscribers(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
}
manager.notifierWg.Add(1)
go manager.notifier()
ch := make(chan SessionState, 10)
manager.subMutex.Lock()
manager.subscribers["test-client"] = ch
manager.subMutex.Unlock()
manager.notifySubscribers()
select {
case state := <-ch:
assert.Equal(t, "1", state.SessionID)
assert.False(t, state.Locked)
case <-time.After(200 * time.Millisecond):
t.Fatal("did not receive state update")
}
close(manager.stopChan)
manager.notifierWg.Wait()
}
func TestManager_NotifySubscribers_Debounce(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
}
manager.notifierWg.Add(1)
go manager.notifier()
ch := make(chan SessionState, 10)
manager.subMutex.Lock()
manager.subscribers["test-client"] = ch
manager.subMutex.Unlock()
manager.notifySubscribers()
manager.notifySubscribers()
manager.notifySubscribers()
receivedCount := 0
timeout := time.After(200 * time.Millisecond)
for {
select {
case <-ch:
receivedCount++
case <-timeout:
assert.Equal(t, 1, receivedCount, "should receive exactly one debounced update")
close(manager.stopChan)
manager.notifierWg.Wait()
return
}
}
}
func TestManager_Close(t *testing.T) {
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
}
ch1 := make(chan SessionState, 1)
ch2 := make(chan SessionState, 1)
manager.subMutex.Lock()
manager.subscribers["client1"] = ch1
manager.subscribers["client2"] = ch2
manager.subMutex.Unlock()
manager.Close()
select {
case <-manager.stopChan:
case <-time.After(100 * time.Millisecond):
t.Fatal("stopChan not closed")
}
_, ok1 := <-ch1
_, ok2 := <-ch2
assert.False(t, ok1, "ch1 should be closed")
assert.False(t, ok2, "ch2 should be closed")
assert.Len(t, manager.subscribers, 0)
}
func TestManager_GetState_ThreadSafe(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
Active: true,
},
stateMutex: sync.RWMutex{},
}
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
state := manager.GetState()
assert.Equal(t, "1", state.SessionID)
assert.True(t, state.Active)
done <- true
}()
}
for i := 0; i < 10; i++ {
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for goroutines")
}
}
}
func TestStateChangedMeaningfully(t *testing.T) {
tests := []struct {
name string
old *SessionState
new *SessionState
expected bool
}{
{
name: "no change",
old: &SessionState{Locked: false, Active: true, IdleHint: false},
new: &SessionState{Locked: false, Active: true, IdleHint: false},
expected: false,
},
{
name: "locked changed",
old: &SessionState{Locked: false, Active: true, IdleHint: false},
new: &SessionState{Locked: true, Active: true, IdleHint: false},
expected: true,
},
{
name: "active changed",
old: &SessionState{Locked: false, Active: true, IdleHint: false},
new: &SessionState{Locked: false, Active: false, IdleHint: false},
expected: true,
},
{
name: "idle hint changed",
old: &SessionState{Locked: false, Active: true, IdleHint: false},
new: &SessionState{Locked: false, Active: true, IdleHint: true},
expected: true,
},
{
name: "locked hint changed",
old: &SessionState{Locked: false, Active: true, LockedHint: false},
new: &SessionState{Locked: false, Active: true, LockedHint: true},
expected: true,
},
{
name: "preparing for sleep changed",
old: &SessionState{Locked: false, Active: true, PreparingForSleep: false},
new: &SessionState{Locked: false, Active: true, PreparingForSleep: true},
expected: true,
},
{
name: "non-meaningful change (username)",
old: &SessionState{Locked: false, Active: true, UserName: "user1"},
new: &SessionState{Locked: false, Active: true, UserName: "user2"},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stateChangedMeaningfully(tt.old, tt.new)
assert.Equal(t, tt.expected, result)
})
}
}
func TestManager_SnapshotState(t *testing.T) {
manager := &Manager{
state: &SessionState{
SessionID: "1",
Locked: false,
Active: true,
UserName: "testuser",
},
stateMutex: sync.RWMutex{},
}
snapshot := manager.snapshotState()
assert.Equal(t, "1", snapshot.SessionID)
assert.False(t, snapshot.Locked)
assert.True(t, snapshot.Active)
assert.Equal(t, "testuser", snapshot.UserName)
snapshot.Locked = true
assert.False(t, manager.state.Locked)
}
func TestNewManager(t *testing.T) {
t.Run("attempts to create manager", func(t *testing.T) {
manager, err := NewManager()
if err != nil {
assert.Nil(t, manager)
} else {
assert.NotNil(t, manager)
assert.NotNil(t, manager.state)
assert.NotNil(t, manager.subscribers)
assert.NotNil(t, manager.stopChan)
manager.Close()
}
})
}

View File

@@ -0,0 +1,157 @@
package loginctl
import (
"time"
"github.com/godbus/dbus/v5"
)
func (m *Manager) handleDBusSignal(sig *dbus.Signal) {
switch sig.Name {
case dbusSessionInterface + ".Lock":
m.stateMutex.Lock()
m.state.Locked = true
m.state.LockedHint = true
m.stateMutex.Unlock()
m.notifySubscribers()
if m.sleepInhibitorEnabled.Load() && m.inSleepCycle.Load() {
id := m.sleepCycleID.Load()
m.lockTimerMu.Lock()
if m.lockTimer != nil {
m.lockTimer.Stop()
}
m.lockTimer = time.AfterFunc(m.fallbackDelay, func() {
m.releaseForCycle(id)
})
m.lockTimerMu.Unlock()
}
case dbusSessionInterface + ".Unlock":
m.stateMutex.Lock()
m.state.Locked = false
m.state.LockedHint = false
m.stateMutex.Unlock()
m.notifySubscribers()
// Cancel the lock timer if it's still running
m.lockTimerMu.Lock()
if m.lockTimer != nil {
m.lockTimer.Stop()
m.lockTimer = nil
}
m.lockTimerMu.Unlock()
// Re-acquire the sleep inhibitor (acquireSleepInhibitor checks the enabled flag)
m.acquireSleepInhibitor()
case dbusManagerInterface + ".PrepareForSleep":
if len(sig.Body) == 0 {
return
}
preparing, _ := sig.Body[0].(bool)
if preparing {
cycleID := m.sleepCycleID.Add(1)
m.inSleepCycle.Store(true)
if m.lockBeforeSuspend.Load() {
m.Lock()
}
readyCh := m.newLockerReadyCh()
go func(id uint64, ch <-chan struct{}) {
<-ch
if m.inSleepCycle.Load() && m.sleepCycleID.Load() == id {
m.releaseSleepInhibitor()
}
}(cycleID, readyCh)
} else {
m.inSleepCycle.Store(false)
m.signalLockerReady()
m.refreshSessionBinding()
m.acquireSleepInhibitor()
}
m.stateMutex.Lock()
m.state.PreparingForSleep = preparing
m.stateMutex.Unlock()
m.notifySubscribers()
case dbusPropsInterface + ".PropertiesChanged":
m.handlePropertiesChanged(sig)
case "org.freedesktop.DBus.NameOwnerChanged":
if len(sig.Body) == 3 {
name, _ := sig.Body[0].(string)
oldOwner, _ := sig.Body[1].(string)
newOwner, _ := sig.Body[2].(string)
if name == dbusDest && oldOwner != "" && newOwner != "" {
m.updateSessionState()
if !m.inSleepCycle.Load() {
m.acquireSleepInhibitor()
}
m.notifySubscribers()
}
}
}
}
func (m *Manager) handlePropertiesChanged(sig *dbus.Signal) {
if len(sig.Body) < 2 {
return
}
iface, ok := sig.Body[0].(string)
if !ok || iface != dbusSessionInterface {
return
}
changes, ok := sig.Body[1].(map[string]dbus.Variant)
if !ok {
return
}
var needsUpdate bool
for key, variant := range changes {
switch key {
case "Active":
if val, ok := variant.Value().(bool); ok {
m.stateMutex.Lock()
m.state.Active = val
m.stateMutex.Unlock()
needsUpdate = true
}
case "IdleHint":
if val, ok := variant.Value().(bool); ok {
m.stateMutex.Lock()
m.state.IdleHint = val
m.stateMutex.Unlock()
needsUpdate = true
}
case "IdleSinceHint":
if val, ok := variant.Value().(uint64); ok {
m.stateMutex.Lock()
m.state.IdleSinceHint = val
m.stateMutex.Unlock()
needsUpdate = true
}
case "LockedHint":
if val, ok := variant.Value().(bool); ok {
m.stateMutex.Lock()
m.state.LockedHint = val
m.state.Locked = val
m.stateMutex.Unlock()
needsUpdate = true
}
}
}
if needsUpdate {
m.notifySubscribers()
}
}

View File

@@ -0,0 +1,322 @@
package loginctl
import (
"sync"
"testing"
"github.com/godbus/dbus/v5"
"github.com/stretchr/testify/assert"
)
func TestManager_HandleDBusSignal_Lock(t *testing.T) {
manager := &Manager{
state: &SessionState{
Locked: false,
LockedHint: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.login1.Session.Lock",
}
manager.handleDBusSignal(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.True(t, manager.state.Locked)
assert.True(t, manager.state.LockedHint)
}
func TestManager_HandleDBusSignal_Unlock(t *testing.T) {
manager := &Manager{
state: &SessionState{
Locked: true,
LockedHint: true,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.login1.Session.Unlock",
}
manager.handleDBusSignal(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.False(t, manager.state.Locked)
assert.False(t, manager.state.LockedHint)
}
func TestManager_HandleDBusSignal_PrepareForSleep(t *testing.T) {
t.Run("preparing for sleep - true", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
PreparingForSleep: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.login1.Manager.PrepareForSleep",
Body: []interface{}{true},
}
manager.handleDBusSignal(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.True(t, manager.state.PreparingForSleep)
})
t.Run("preparing for sleep - false", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
PreparingForSleep: true,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.login1.Manager.PrepareForSleep",
Body: []interface{}{false},
}
manager.handleDBusSignal(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.False(t, manager.state.PreparingForSleep)
})
t.Run("empty body", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
PreparingForSleep: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.login1.Manager.PrepareForSleep",
Body: []interface{}{},
}
manager.handleDBusSignal(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.False(t, manager.state.PreparingForSleep)
})
}
func TestManager_HandlePropertiesChanged(t *testing.T) {
t.Run("active property changed", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
Active: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{
"org.freedesktop.login1.Session",
map[string]dbus.Variant{
"Active": dbus.MakeVariant(true),
},
},
}
manager.handlePropertiesChanged(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.True(t, manager.state.Active)
})
t.Run("idle hint property changed", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
IdleHint: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{
"org.freedesktop.login1.Session",
map[string]dbus.Variant{
"IdleHint": dbus.MakeVariant(true),
},
},
}
manager.handlePropertiesChanged(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.True(t, manager.state.IdleHint)
})
t.Run("idle since hint property changed", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
IdleSinceHint: 0,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{
"org.freedesktop.login1.Session",
map[string]dbus.Variant{
"IdleSinceHint": dbus.MakeVariant(uint64(123456789)),
},
},
}
manager.handlePropertiesChanged(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.Equal(t, uint64(123456789), manager.state.IdleSinceHint)
})
t.Run("locked hint property changed", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
LockedHint: false,
Locked: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{
"org.freedesktop.login1.Session",
map[string]dbus.Variant{
"LockedHint": dbus.MakeVariant(true),
},
},
}
manager.handlePropertiesChanged(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.True(t, manager.state.LockedHint)
assert.True(t, manager.state.Locked)
})
t.Run("wrong interface", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
Active: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{
"org.freedesktop.SomeOtherInterface",
map[string]dbus.Variant{
"Active": dbus.MakeVariant(true),
},
},
}
manager.handlePropertiesChanged(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.False(t, manager.state.Active)
})
t.Run("empty body", func(t *testing.T) {
manager := &Manager{
state: &SessionState{},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{},
}
assert.NotPanics(t, func() {
manager.handlePropertiesChanged(sig)
})
})
t.Run("multiple properties changed", func(t *testing.T) {
manager := &Manager{
state: &SessionState{
Active: false,
IdleHint: false,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan SessionState),
subMutex: sync.RWMutex{},
dirty: make(chan struct{}, 1),
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{
"org.freedesktop.login1.Session",
map[string]dbus.Variant{
"Active": dbus.MakeVariant(true),
"IdleHint": dbus.MakeVariant(true),
},
},
}
manager.handlePropertiesChanged(sig)
manager.stateMutex.RLock()
defer manager.stateMutex.RUnlock()
assert.True(t, manager.state.Active)
assert.True(t, manager.state.IdleHint)
})
}

View File

@@ -0,0 +1,76 @@
package loginctl
import (
"os"
"sync"
"sync/atomic"
"time"
"github.com/godbus/dbus/v5"
)
type SessionState struct {
SessionID string `json:"sessionId"`
SessionPath string `json:"sessionPath"`
Locked bool `json:"locked"`
Active bool `json:"active"`
IdleHint bool `json:"idleHint"`
IdleSinceHint uint64 `json:"idleSinceHint"`
LockedHint bool `json:"lockedHint"`
SessionType string `json:"sessionType"`
SessionClass string `json:"sessionClass"`
User uint32 `json:"user"`
UserName string `json:"userName"`
RemoteHost string `json:"remoteHost"`
Service string `json:"service"`
TTY string `json:"tty"`
Display string `json:"display"`
Remote bool `json:"remote"`
Seat string `json:"seat"`
VTNr uint32 `json:"vtnr"`
PreparingForSleep bool `json:"preparingForSleep"`
}
type EventType string
const (
EventStateChanged EventType = "state_changed"
EventLock EventType = "lock"
EventUnlock EventType = "unlock"
EventPrepareForSleep EventType = "prepare_for_sleep"
EventIdleHintChanged EventType = "idle_hint_changed"
EventLockedHintChanged EventType = "locked_hint_changed"
)
type SessionEvent struct {
Type EventType `json:"type"`
Data SessionState `json:"data"`
}
type Manager struct {
state *SessionState
stateMutex sync.RWMutex
subscribers map[string]chan SessionState
subMutex sync.RWMutex
stopChan chan struct{}
conn *dbus.Conn
sessionPath dbus.ObjectPath
managerObj dbus.BusObject
sessionObj dbus.BusObject
dirty chan struct{}
notifierWg sync.WaitGroup
lastNotifiedState *SessionState
signals chan *dbus.Signal
sigWG sync.WaitGroup
inhibitMu sync.Mutex
inhibitFile *os.File
lockBeforeSuspend atomic.Bool
inSleepCycle atomic.Bool
sleepCycleID atomic.Uint64
lockerReadyChMu sync.Mutex
lockerReadyCh chan struct{}
lockTimerMu sync.Mutex
lockTimer *time.Timer
sleepInhibitorEnabled atomic.Bool
fallbackDelay time.Duration
}

View File

@@ -0,0 +1,63 @@
package loginctl
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestEventType_Constants(t *testing.T) {
assert.Equal(t, EventType("state_changed"), EventStateChanged)
assert.Equal(t, EventType("lock"), EventLock)
assert.Equal(t, EventType("unlock"), EventUnlock)
assert.Equal(t, EventType("prepare_for_sleep"), EventPrepareForSleep)
assert.Equal(t, EventType("idle_hint_changed"), EventIdleHintChanged)
assert.Equal(t, EventType("locked_hint_changed"), EventLockedHintChanged)
}
func TestSessionState_Struct(t *testing.T) {
state := SessionState{
SessionID: "1",
SessionPath: "/org/freedesktop/login1/session/_31",
Locked: false,
Active: true,
IdleHint: false,
IdleSinceHint: 0,
LockedHint: false,
SessionType: "wayland",
SessionClass: "user",
User: 1000,
UserName: "testuser",
RemoteHost: "",
Service: "gdm-password",
TTY: "tty2",
Display: ":1",
Remote: false,
Seat: "seat0",
VTNr: 2,
PreparingForSleep: false,
}
assert.Equal(t, "1", state.SessionID)
assert.True(t, state.Active)
assert.False(t, state.Locked)
assert.Equal(t, "wayland", state.SessionType)
assert.Equal(t, uint32(1000), state.User)
assert.Equal(t, "testuser", state.UserName)
}
func TestSessionEvent_Struct(t *testing.T) {
state := SessionState{
SessionID: "1",
Locked: true,
}
event := SessionEvent{
Type: EventLock,
Data: state,
}
assert.Equal(t, EventLock, event.Type)
assert.Equal(t, "1", event.Data.SessionID)
assert.True(t, event.Data.Locked)
}

View File

@@ -0,0 +1,31 @@
package models
import (
"encoding/json"
"net"
"github.com/AvengeMedia/danklinux/internal/log"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type Response[T any] struct {
ID int `json:"id,omitempty"`
Result *T `json:"result,omitempty"`
Error string `json:"error,omitempty"`
}
func RespondError(conn net.Conn, id int, errMsg string) {
log.Errorf("DMS API Error: id=%d error=%s", id, errMsg)
resp := Response[any]{ID: id, Error: errMsg}
json.NewEncoder(conn).Encode(resp)
}
func Respond[T any](conn net.Conn, id int, result T) {
resp := Response[T]{ID: id, Result: &result}
json.NewEncoder(conn).Encode(resp)
}

View File

@@ -0,0 +1,552 @@
# NetworkManager API Documentation
## Overview
The network manager API provides methods for managing WiFi connections, monitoring network state, and handling credential prompts through NetworkManager. Communication occurs over a message-based protocol (websocket, IPC, etc.) with event subscriptions for state updates.
## API Methods
### network.wifi.connect
Initiate a WiFi connection.
**Request:**
```json
{
"method": "network.wifi.connect",
"params": {
"ssid": "NetworkName",
"password": "optional-password",
"interactive": true
}
}
```
**Parameters:**
- `ssid` (string, required): Network SSID
- `password` (string, optional): Pre-shared key for WPA/WPA2/WPA3 networks
- `interactive` (boolean, optional): Enable credential prompting if authentication fails or password is missing. Automatically set to `true` when connecting to secured networks without providing a password.
**Response:**
```json
{
"success": true,
"message": "connecting"
}
```
**Behavior:**
- Returns immediately; connection happens asynchronously
- State updates delivered via `network` service subscription
- Credential prompts delivered via `network.credentials` service subscription
### network.credentials.submit
Submit credentials in response to a prompt.
**Request:**
```json
{
"method": "network.credentials.submit",
"params": {
"token": "correlation-token",
"secrets": {
"psk": "password"
},
"save": true
}
}
```
**Parameters:**
- `token` (string, required): Token from credential prompt
- `secrets` (object, required): Key-value map of credential fields
- `save` (boolean, optional): Whether to persist credentials (default: false)
**Common secret fields:**
- `psk`: Pre-shared key for WPA2/WPA3 personal networks
- `identity`: Username for 802.1X enterprise networks
- `password`: Password for 802.1X enterprise networks
### network.credentials.cancel
Cancel a credential prompt.
**Request:**
```json
{
"method": "network.credentials.cancel",
"params": {
"token": "correlation-token"
}
}
```
## Event Subscriptions
### Subscribing to Events
Subscribe to receive network state updates and credential prompts:
```json
{
"method": "subscribe",
"params": {
"services": ["network", "network.credentials"]
}
}
```
Both services are required for full connection handling. Missing `network.credentials` means credential prompts won't be received.
### network Service Events
State updates are sent whenever network configuration changes:
```json
{
"service": "network",
"data": {
"networkStatus": "wifi",
"isConnecting": false,
"connectingSSID": "",
"wifiConnected": true,
"wifiSSID": "MyNetwork",
"wifiIP": "192.168.1.100",
"lastError": ""
}
}
```
**State fields:**
- `networkStatus`: Current connection type (`wifi`, `ethernet`, `disconnected`)
- `isConnecting`: Whether a connection attempt is in progress
- `connectingSSID`: SSID being connected to (empty when idle)
- `wifiConnected`: Whether associated with an access point
- `wifiSSID`: Currently connected network name
- `wifiIP`: Assigned IP address (empty until DHCP completes)
- `lastError`: Error message from last failed connection attempt
### network.credentials Service Events
Credential prompts are sent when authentication is required:
```json
{
"service": "network.credentials",
"data": {
"token": "unique-prompt-id",
"ssid": "NetworkName",
"setting": "802-11-wireless-security",
"fields": ["psk"],
"hints": ["wpa3", "sae"],
"reason": "Credentials required"
}
}
```
**Prompt fields:**
- `token`: Unique identifier for this prompt (use in submit/cancel)
- `ssid`: Network requesting credentials
- `setting`: Authentication type (`802-11-wireless-security` for personal WiFi, `802-1x` for enterprise)
- `fields`: Array of required credential field names
- `hints`: Additional context about the network type
- `reason`: Human-readable explanation (e.g., "Previous password was incorrect")
## Connection Flow
### Typical Timeline
```
T+0ms Call network.wifi.connect
T+10ms Receive {"success": true, "message": "connecting"}
T+100ms State update: isConnecting=true, connectingSSID="Network"
T+500ms Credential prompt (if needed)
T+1000ms Submit credentials
T+3000ms State update: wifiConnected=true, wifiIP="192.168.x.x"
```
### State Machine
```
IDLE
|
| network.wifi.connect
v
CONNECTING (isConnecting=true, connectingSSID set)
|
+-- Needs credentials
| |
| v
| PROMPTING (credential prompt event)
| |
| | network.credentials.submit
| v
| back to CONNECTING
|
+-- Success
| |
| v
| CONNECTED (wifiConnected=true, wifiIP set, isConnecting=false)
|
+-- Failure
|
v
ERROR (isConnecting=false, !wifiConnected, lastError set)
```
## Connection Success Detection
A connection is successful when all of the following are true:
1. `wifiConnected` is `true`
2. `wifiIP` is set and non-empty
3. `wifiSSID` matches the target network
4. `isConnecting` is `false`
Do not rely on `wifiConnected` alone - the device may be associated with an access point but not have an IP address yet.
**Example:**
```javascript
function isConnectionComplete(state, targetSSID) {
return state.wifiConnected &&
state.wifiIP &&
state.wifiIP !== "" &&
state.wifiSSID === targetSSID &&
!state.isConnecting;
}
```
## Error Handling
### Error Detection
Errors occur when a connection attempt stops without success:
```javascript
function checkForFailure(state, wasConnecting, targetSSID) {
// Was connecting, now idle, but not connected
if (wasConnecting &&
!state.isConnecting &&
state.connectingSSID === "" &&
!state.wifiConnected) {
return state.lastError || "Connection failed";
}
return null;
}
```
### Common Error Scenarios
#### Wrong Password
**Detection methods:**
1. Quick failure (< 3 seconds from start)
2. `lastError` contains "password", "auth", or "secrets"
3. Second credential prompt with `reason: "Previous password was incorrect"`
**Handling:**
```javascript
if (prompt.reason === "Previous password was incorrect") {
// Show error, clear password field, re-focus input
}
```
#### Network Out of Range
**Detection:**
- `lastError` contains "not-found" or "connection-attempt-failed"
#### Connection Timeout
**Detection:**
- `isConnecting` remains true for > 30 seconds
**Implementation:**
```javascript
let timeout = setTimeout(() => {
if (currentState.isConnecting) {
handleTimeout();
}
}, 30000);
```
#### DHCP Failure
**Detection:**
- `wifiConnected` is true
- `wifiIP` is empty after 15+ seconds
### Error Message Translation
Map technical errors to user-friendly messages:
| lastError value | Meaning | User message |
|----------------|---------|--------------|
| `secrets-required` | Password needed | "Please enter password" |
| `authentication-failed` | Wrong password | "Incorrect password" |
| `connection-removed` | Profile deleted | "Network configuration removed" |
| `connection-attempt-failed` | Generic failure | "Failed to connect" |
| `network-not-found` | Out of range | "Network not found" |
| `(timeout)` | Timeout | "Connection timed out" |
## Credential Handling
### Secret Agent Architecture
The credential system uses a broker pattern:
```
NetworkManager -> SecretAgent -> PromptBroker -> UI -> User
^
|
User Response
|
NetworkManager <- SecretAgent <- PromptBroker <- UI
```
### Implementing a Broker
```go
type CustomBroker struct {
ui UIInterface
pending map[string]chan network.PromptReply
}
func (b *CustomBroker) Ask(ctx context.Context, req network.PromptRequest) (string, error) {
token := generateToken()
b.pending[token] = make(chan network.PromptReply, 1)
// Send to UI
b.ui.ShowCredentialPrompt(token, req)
return token, nil
}
func (b *CustomBroker) Wait(ctx context.Context, token string) (network.PromptReply, error) {
select {
case <-ctx.Done():
return network.PromptReply{}, errors.New("timeout")
case reply := <-b.pending[token]:
return reply, nil
}
}
func (b *CustomBroker) Resolve(token string, reply network.PromptReply) error {
if ch, ok := b.pending[token]; ok {
ch <- reply
close(ch)
delete(b.pending, token)
}
return nil
}
```
### Credential Field Types
**Personal WiFi (802-11-wireless-security):**
- Fields: `["psk"]`
- UI: Single password input
**Enterprise WiFi (802-1x):**
- Fields: `["identity", "password"]`
- UI: Username and password inputs
### Building Secrets Object
```javascript
function buildSecrets(setting, fields, formData) {
let secrets = {};
if (setting === "802-11-wireless-security") {
secrets.psk = formData.password;
} else if (setting === "802-1x") {
secrets.identity = formData.username;
secrets.password = formData.password;
}
return secrets;
}
```
## Best Practices
### Track Target Network
Always store which network you're connecting to:
```javascript
let targetSSID = null;
function connect(ssid) {
targetSSID = ssid;
// send request
}
function onStateUpdate(state) {
if (!targetSSID) return;
if (state.wifiSSID === targetSSID && state.wifiConnected && state.wifiIP) {
// Success for the network we care about
targetSSID = null;
}
}
```
### Implement Timeouts
Never wait indefinitely for a connection:
```javascript
const CONNECTION_TIMEOUT = 30000; // 30 seconds
const DHCP_TIMEOUT = 15000; // 15 seconds
let timer = setTimeout(() => {
if (stillConnecting) {
handleTimeout();
}
}, CONNECTION_TIMEOUT);
```
### Handle Credential Re-prompts
Wrong passwords trigger a second prompt:
```javascript
function onCredentialPrompt(prompt) {
if (prompt.reason.includes("incorrect")) {
// Show error, but keep dialog open
showError("Wrong password");
clearPasswordField();
} else {
// First time prompt
showDialog(prompt);
}
}
```
### Clean Up State
Reset tracking variables on success, failure, or cancellation:
```javascript
function cleanup() {
clearTimeout(timer);
targetSSID = null;
closeDialogs();
}
```
### Subscribe to Both Services
Missing `network.credentials` means prompts won't arrive:
```javascript
// Correct
services: ["network", "network.credentials"]
// Wrong - will miss credential prompts
services: ["network"]
```
## Testing
### Connection Test Checklist
- [ ] Connect to open network
- [ ] Connect to WPA2 network with password provided
- [ ] Connect to WPA2 network without password (triggers prompt)
- [ ] Enter wrong password (verify error and re-prompt)
- [ ] Cancel credential prompt
- [ ] Connection timeout after 30 seconds
- [ ] DHCP timeout detection
- [ ] Network out of range
- [ ] Reconnect to already-configured network
### Verifying Secret Agent Setup
Check connection profile flags:
```bash
nmcli connection show "NetworkName" | grep flags
# Should show: 802-11-wireless-security.psk-flags: 1 (agent-owned)
```
Check agent registration in logs:
```
INFO: Registered with NetworkManager as secret agent
```
## Security
- Never log credential values (passwords, PSKs)
- Clear password fields when dialogs close
- Implement prompt timeouts (default: 2 minutes)
- Validate user input before submission
- Use secure channels for credential transmission
## Troubleshooting
### Credential prompt doesn't appear
**Check:**
- Subscribed to both `network` and `network.credentials`
- Connection has `interactive: true`
- Secret flags set to AGENT_OWNED (value: 1)
- Broker registered successfully
### Connection succeeds without prompting
**Cause:** NetworkManager found saved credentials
**Solution:** Delete existing connection first, or use different credentials
### State updates seem delayed
**Expected behavior:** State changes occur in rapid succession during connection
**Solution:** Debounce UI updates; only act on final state
### Multiple rapid credential prompts
**Cause:** Connection profile has incorrect flags or conflicting agents
**Solution:**
- Check only one agent is running
- Verify psk-flags value
- Check NetworkManager logs for agent conflicts
## Data Structures Reference
### PromptRequest
```go
type PromptRequest struct {
SSID string `json:"ssid"`
SettingName string `json:"setting"`
Fields []string `json:"fields"`
Hints []string `json:"hints"`
Reason string `json:"reason"`
}
```
### PromptReply
```go
type PromptReply struct {
Secrets map[string]string `json:"secrets"`
Save bool `json:"save"`
Cancel bool `json:"cancel"`
}
```
### NetworkState
```go
type NetworkState struct {
NetworkStatus string `json:"networkStatus"`
IsConnecting bool `json:"isConnecting"`
ConnectingSSID string `json:"connectingSSID"`
WifiConnected bool `json:"wifiConnected"`
WifiSSID string `json:"wifiSSID"`
WifiIP string `json:"wifiIP"`
LastError string `json:"lastError"`
}
```

View File

@@ -0,0 +1,306 @@
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/AvengeMedia/danklinux/internal/errdefs"
"github.com/godbus/dbus/v5"
)
const (
iwdAgentManagerPath = "/net/connman/iwd"
iwdAgentManagerIface = "net.connman.iwd.AgentManager"
iwdAgentInterface = "net.connman.iwd.Agent"
iwdAgentObjectPath = "/com/danklinux/iwdagent"
)
type ConnectionStateChecker interface {
IsConnectingTo(ssid string) bool
}
type IWDAgent struct {
conn *dbus.Conn
objPath dbus.ObjectPath
prompts PromptBroker
onUserCanceled func()
onPromptRetry func(ssid string)
lastRequestSSID string
stateChecker ConnectionStateChecker
}
const iwdAgentIntrospectXML = `
<node>
<interface name="net.connman.iwd.Agent">
<method name="Release">
<annotation name="org.freedesktop.DBus.Method.NoReply" value="true"/>
</method>
<method name="RequestPassphrase">
<arg type="o" name="network" direction="in"/>
<arg type="s" name="passphrase" direction="out"/>
</method>
<method name="RequestPrivateKeyPassphrase">
<arg type="o" name="network" direction="in"/>
<arg type="s" name="passphrase" direction="out"/>
</method>
<method name="RequestUserNameAndPassword">
<arg type="o" name="network" direction="in"/>
<arg type="s" name="username" direction="out"/>
<arg type="s" name="password" direction="out"/>
</method>
<method name="RequestUserPassword">
<arg type="o" name="network" direction="in"/>
<arg type="s" name="user" direction="in"/>
<arg type="s" name="password" direction="out"/>
</method>
<method name="Cancel">
<arg type="s" name="reason" direction="in"/>
<annotation name="org.freedesktop.DBus.Method.NoReply" value="true"/>
</method>
</interface>
</node>`
func NewIWDAgent(conn *dbus.Conn, prompts PromptBroker) (*IWDAgent, error) {
if conn == nil {
return nil, fmt.Errorf("dbus connection is nil")
}
agent := &IWDAgent{
conn: conn,
objPath: dbus.ObjectPath(iwdAgentObjectPath),
prompts: prompts,
}
if err := conn.Export(agent, agent.objPath, iwdAgentInterface); err != nil {
return nil, fmt.Errorf("failed to export IWD agent: %w", err)
}
if err := conn.Export(agent, agent.objPath, "org.freedesktop.DBus.Introspectable"); err != nil {
return nil, fmt.Errorf("failed to export introspection: %w", err)
}
mgr := conn.Object("net.connman.iwd", dbus.ObjectPath(iwdAgentManagerPath))
call := mgr.Call(iwdAgentManagerIface+".RegisterAgent", 0, agent.objPath)
if call.Err != nil {
return nil, fmt.Errorf("failed to register agent with iwd: %w", call.Err)
}
return agent, nil
}
func (a *IWDAgent) Close() {
if a.conn != nil {
mgr := a.conn.Object("net.connman.iwd", dbus.ObjectPath(iwdAgentManagerPath))
mgr.Call(iwdAgentManagerIface+".UnregisterAgent", 0, a.objPath)
}
}
func (a *IWDAgent) SetStateChecker(checker ConnectionStateChecker) {
a.stateChecker = checker
}
func (a *IWDAgent) getNetworkName(networkPath dbus.ObjectPath) string {
netObj := a.conn.Object("net.connman.iwd", networkPath)
nameVar, err := netObj.GetProperty("net.connman.iwd.Network.Name")
if err == nil {
if name, ok := nameVar.Value().(string); ok {
return name
}
}
return string(networkPath)
}
func (a *IWDAgent) RequestPassphrase(network dbus.ObjectPath) (string, *dbus.Error) {
ssid := a.getNetworkName(network)
if a.stateChecker != nil && !a.stateChecker.IsConnectingTo(ssid) {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.prompts == nil {
if a.onUserCanceled != nil {
a.onUserCanceled()
}
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.lastRequestSSID == ssid {
if a.onPromptRetry != nil {
a.onPromptRetry(ssid)
}
}
a.lastRequestSSID = ssid
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
token, err := a.prompts.Ask(ctx, PromptRequest{
SSID: ssid,
Fields: []string{"psk"},
})
if err != nil {
if a.onUserCanceled != nil {
a.onUserCanceled()
}
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
reply, err := a.prompts.Wait(ctx, token)
if err != nil {
if reply.Cancel || errors.Is(err, errdefs.ErrSecretPromptCancelled) {
if a.onUserCanceled != nil {
a.onUserCanceled()
}
}
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if passphrase, ok := reply.Secrets["psk"]; ok {
return passphrase, nil
}
if a.onUserCanceled != nil {
a.onUserCanceled()
}
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
func (a *IWDAgent) RequestPrivateKeyPassphrase(network dbus.ObjectPath) (string, *dbus.Error) {
ssid := a.getNetworkName(network)
if a.stateChecker != nil && !a.stateChecker.IsConnectingTo(ssid) {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.prompts == nil {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.lastRequestSSID == ssid {
if a.onPromptRetry != nil {
a.onPromptRetry(ssid)
}
}
a.lastRequestSSID = ssid
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
token, err := a.prompts.Ask(ctx, PromptRequest{
SSID: ssid,
Fields: []string{"private-key-password"},
})
if err != nil {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
reply, err := a.prompts.Wait(ctx, token)
if err != nil || reply.Cancel {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if passphrase, ok := reply.Secrets["private-key-password"]; ok {
return passphrase, nil
}
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
func (a *IWDAgent) RequestUserNameAndPassword(network dbus.ObjectPath) (string, string, *dbus.Error) {
ssid := a.getNetworkName(network)
if a.stateChecker != nil && !a.stateChecker.IsConnectingTo(ssid) {
return "", "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.prompts == nil {
return "", "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.lastRequestSSID == ssid {
if a.onPromptRetry != nil {
a.onPromptRetry(ssid)
}
}
a.lastRequestSSID = ssid
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
token, err := a.prompts.Ask(ctx, PromptRequest{
SSID: ssid,
Fields: []string{"identity", "password"},
})
if err != nil {
return "", "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
reply, err := a.prompts.Wait(ctx, token)
if err != nil || reply.Cancel {
return "", "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
username, hasUser := reply.Secrets["identity"]
password, hasPass := reply.Secrets["password"]
if hasUser && hasPass {
return username, password, nil
}
return "", "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
func (a *IWDAgent) RequestUserPassword(network dbus.ObjectPath, user string) (string, *dbus.Error) {
ssid := a.getNetworkName(network)
if a.stateChecker != nil && !a.stateChecker.IsConnectingTo(ssid) {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.prompts == nil {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if a.lastRequestSSID == ssid {
if a.onPromptRetry != nil {
a.onPromptRetry(ssid)
}
}
a.lastRequestSSID = ssid
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
token, err := a.prompts.Ask(ctx, PromptRequest{
SSID: ssid,
Fields: []string{"password"},
})
if err != nil {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
reply, err := a.prompts.Wait(ctx, token)
if err != nil || reply.Cancel {
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
if password, ok := reply.Secrets["password"]; ok {
return password, nil
}
return "", dbus.NewError("net.connman.iwd.Agent.Error.Canceled", nil)
}
func (a *IWDAgent) Cancel(reason string) *dbus.Error {
return nil
}
func (a *IWDAgent) Release() *dbus.Error {
return nil
}
func (a *IWDAgent) Introspect() (string, *dbus.Error) {
return iwdAgentIntrospectXML, nil
}

View File

@@ -0,0 +1,528 @@
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/AvengeMedia/danklinux/internal/errdefs"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/godbus/dbus/v5"
)
const (
nmAgentManagerPath = "/org/freedesktop/NetworkManager/AgentManager"
nmAgentManagerIface = "org.freedesktop.NetworkManager.AgentManager"
nmSecretAgentIface = "org.freedesktop.NetworkManager.SecretAgent"
agentObjectPath = "/org/freedesktop/NetworkManager/SecretAgent"
agentIdentifier = "com.danklinux.NMAgent"
)
type SecretAgent struct {
conn *dbus.Conn
objPath dbus.ObjectPath
id string
prompts PromptBroker
manager *Manager
backend *NetworkManagerBackend
}
type nmVariantMap map[string]dbus.Variant
type nmSettingMap map[string]nmVariantMap
const introspectXML = `
<node>
<interface name="org.freedesktop.NetworkManager.SecretAgent">
<method name="GetSecrets">
<arg type="a{sa{sv}}" name="connection" direction="in"/>
<arg type="o" name="connection_path" direction="in"/>
<arg type="s" name="setting_name" direction="in"/>
<arg type="as" name="hints" direction="in"/>
<arg type="u" name="flags" direction="in"/>
<arg type="a{sa{sv}}" name="secrets" direction="out"/>
</method>
<method name="DeleteSecrets">
<arg type="a{sa{sv}}" name="connection" direction="in"/>
<arg type="o" name="connection_path" direction="in"/>
</method>
<method name="DeleteSecrets2">
<arg type="o" name="connection_path" direction="in"/>
<arg type="s" name="setting" direction="in"/>
</method>
<method name="CancelGetSecrets">
<arg type="o" name="connection_path" direction="in"/>
<arg type="s" name="setting_name" direction="in"/>
</method>
</interface>
<interface name="org.freedesktop.DBus.Introspectable">
<method name="Introspect">
<arg name="data" type="s" direction="out"/>
</method>
</interface>
</node>`
func NewSecretAgent(prompts PromptBroker, manager *Manager, backend *NetworkManagerBackend) (*SecretAgent, error) {
c, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("failed to connect to system bus: %w", err)
}
sa := &SecretAgent{
conn: c,
objPath: dbus.ObjectPath(agentObjectPath),
id: agentIdentifier,
prompts: prompts,
manager: manager,
backend: backend,
}
if err := c.Export(sa, sa.objPath, nmSecretAgentIface); err != nil {
c.Close()
return nil, fmt.Errorf("failed to export secret agent: %w", err)
}
if err := c.Export(sa, sa.objPath, "org.freedesktop.DBus.Introspectable"); err != nil {
c.Close()
return nil, fmt.Errorf("failed to export introspection: %w", err)
}
mgr := c.Object("org.freedesktop.NetworkManager", dbus.ObjectPath(nmAgentManagerPath))
call := mgr.Call(nmAgentManagerIface+".Register", 0, sa.id)
if call.Err != nil {
c.Close()
return nil, fmt.Errorf("failed to register agent with NetworkManager: %w", call.Err)
}
log.Infof("[SecretAgent] Registered with NetworkManager (id=%s, unique name=%s, fixed path=%s)", sa.id, c.Names()[0], sa.objPath)
return sa, nil
}
func (a *SecretAgent) Close() {
if a.conn != nil {
mgr := a.conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath(nmAgentManagerPath))
mgr.Call(nmAgentManagerIface+".Unregister", 0, a.id)
a.conn.Close()
}
}
func (a *SecretAgent) GetSecrets(
conn map[string]nmVariantMap,
path dbus.ObjectPath,
settingName string,
hints []string,
flags uint32,
) (nmSettingMap, *dbus.Error) {
log.Infof("[SecretAgent] GetSecrets called: path=%s, setting=%s, hints=%v, flags=%d",
path, settingName, hints, flags)
const (
NM_SECRET_AGENT_GET_SECRETS_FLAG_ALLOW_INTERACTION = 0x1
NM_SECRET_AGENT_GET_SECRETS_FLAG_REQUEST_NEW = 0x2
NM_SECRET_AGENT_GET_SECRETS_FLAG_USER_REQUESTED = 0x4
)
connType, displayName, vpnSvc := readConnTypeAndName(conn)
ssid := readSSID(conn)
fields := fieldsNeeded(settingName, hints)
log.Infof("[SecretAgent] connType=%s, name=%s, vpnSvc=%s, fields=%v, flags=%d", connType, displayName, vpnSvc, fields, flags)
if a.backend != nil {
a.backend.stateMutex.RLock()
isConnecting := a.backend.state.IsConnecting
connectingSSID := a.backend.state.ConnectingSSID
isConnectingVPN := a.backend.state.IsConnectingVPN
connectingVPNUUID := a.backend.state.ConnectingVPNUUID
a.backend.stateMutex.RUnlock()
switch connType {
case "802-11-wireless":
// If we're connecting to a WiFi network, only respond if it's the one we're connecting to
if isConnecting && connectingSSID != ssid {
log.Infof("[SecretAgent] Ignoring WiFi request for SSID '%s' - we're connecting to '%s'", ssid, connectingSSID)
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.NoSecrets", nil)
}
case "vpn", "wireguard":
var connUuid string
if c, ok := conn["connection"]; ok {
if v, ok := c["uuid"]; ok {
if s, ok2 := v.Value().(string); ok2 {
connUuid = s
}
}
}
// If we're connecting to a VPN, only respond if it's the one we're connecting to
// This prevents interfering with nmcli/other tools when our app isn't connecting
if isConnectingVPN && connUuid != connectingVPNUUID {
log.Infof("[SecretAgent] Ignoring VPN request for UUID '%s' - we're connecting to '%s'", connUuid, connectingVPNUUID)
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.NoSecrets", nil)
}
}
}
if len(fields) == 0 {
// For VPN connections with no hints, we can't provide a proper UI.
// Defer to other agents (like nm-applet or VPN-specific auth dialogs)
// that can handle the VPN type properly (e.g., OpenConnect with SAML, etc.)
if settingName == "vpn" {
log.Infof("[SecretAgent] VPN with empty hints - deferring to other agents for %s", vpnSvc)
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.NoSecrets", nil)
}
const (
NM_SETTING_SECRET_FLAG_NONE = 0
NM_SETTING_SECRET_FLAG_AGENT_OWNED = 1
NM_SETTING_SECRET_FLAG_NOT_SAVED = 2
NM_SETTING_SECRET_FLAG_NOT_REQUIRED = 4
)
var passwordFlags uint32 = 0xFFFF
switch settingName {
case "802-11-wireless-security":
if wifiSecSettings, ok := conn["802-11-wireless-security"]; ok {
if flagsVariant, ok := wifiSecSettings["psk-flags"]; ok {
if pwdFlags, ok := flagsVariant.Value().(uint32); ok {
passwordFlags = pwdFlags
}
}
}
case "802-1x":
if dot1xSettings, ok := conn["802-1x"]; ok {
if flagsVariant, ok := dot1xSettings["password-flags"]; ok {
if pwdFlags, ok := flagsVariant.Value().(uint32); ok {
passwordFlags = pwdFlags
}
}
}
}
if passwordFlags == 0xFFFF {
log.Warnf("[SecretAgent] Could not determine password-flags for empty hints - returning NoSecrets error")
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.NoSecrets", nil)
} else if passwordFlags&NM_SETTING_SECRET_FLAG_NOT_REQUIRED != 0 {
log.Infof("[SecretAgent] Secrets not required (flags=%d)", passwordFlags)
out := nmSettingMap{}
out[settingName] = nmVariantMap{}
return out, nil
} else if passwordFlags&NM_SETTING_SECRET_FLAG_AGENT_OWNED != 0 {
log.Warnf("[SecretAgent] Secrets are agent-owned but we don't store secrets (flags=%d) - returning NoSecrets error", passwordFlags)
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.NoSecrets", nil)
} else {
log.Infof("[SecretAgent] No secrets needed, using system stored secrets (flags=%d)", passwordFlags)
out := nmSettingMap{}
out[settingName] = nmVariantMap{}
return out, nil
}
}
reason := reasonFromFlags(flags)
if a.manager != nil && connType == "802-11-wireless" && a.manager.WasRecentlyFailed(ssid) {
reason = "wrong-password"
}
var connId, connUuid string
if c, ok := conn["connection"]; ok {
if v, ok := c["id"]; ok {
if s, ok2 := v.Value().(string); ok2 {
connId = s
}
}
if v, ok := c["uuid"]; ok {
if s, ok2 := v.Value().(string); ok2 {
connUuid = s
}
}
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
token, err := a.prompts.Ask(ctx, PromptRequest{
Name: displayName,
SSID: ssid,
ConnType: connType,
VpnService: vpnSvc,
SettingName: settingName,
Fields: fields,
Hints: hints,
Reason: reason,
ConnectionId: connId,
ConnectionUuid: connUuid,
ConnectionPath: string(path),
})
if err != nil {
log.Warnf("[SecretAgent] Failed to create prompt: %v", err)
return nil, dbus.MakeFailedError(err)
}
log.Infof("[SecretAgent] Waiting for user input (token=%s)", token)
reply, err := a.prompts.Wait(ctx, token)
if err != nil {
log.Warnf("[SecretAgent] Prompt failed or cancelled: %v", err)
// Clear connecting state immediately on cancellation
if a.backend != nil {
a.backend.stateMutex.Lock()
wasConnecting := a.backend.state.IsConnecting
wasConnectingVPN := a.backend.state.IsConnectingVPN
cancelledSSID := a.backend.state.ConnectingSSID
if wasConnecting || wasConnectingVPN {
log.Infof("[SecretAgent] Clearing connecting state due to cancelled prompt")
a.backend.state.IsConnecting = false
a.backend.state.ConnectingSSID = ""
a.backend.state.IsConnectingVPN = false
a.backend.state.ConnectingVPNUUID = ""
}
a.backend.stateMutex.Unlock()
// If this was a WiFi connection that was just cancelled, remove the connection profile
// (it was created with AddConnection but activation was cancelled)
if wasConnecting && cancelledSSID != "" && connType == "802-11-wireless" {
log.Infof("[SecretAgent] Removing connection profile for cancelled WiFi connection: %s", cancelledSSID)
if err := a.backend.ForgetWiFiNetwork(cancelledSSID); err != nil {
log.Warnf("[SecretAgent] Failed to remove cancelled connection profile: %v", err)
}
}
if (wasConnecting || wasConnectingVPN) && a.backend.onStateChange != nil {
a.backend.onStateChange()
}
}
if reply.Cancel || errors.Is(err, errdefs.ErrSecretPromptCancelled) {
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.UserCanceled", nil)
}
if errors.Is(err, errdefs.ErrSecretPromptTimeout) {
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.Failed", nil)
}
return nil, dbus.NewError("org.freedesktop.NetworkManager.SecretAgent.Error.Failed", nil)
}
log.Infof("[SecretAgent] User provided secrets, save=%v", reply.Save)
out := nmSettingMap{}
sec := nmVariantMap{}
for k, v := range reply.Secrets {
sec[k] = dbus.MakeVariant(v)
}
out[settingName] = sec
switch settingName {
case "802-1x":
log.Infof("[SecretAgent] Returning 802-1x enterprise secrets with %d fields", len(sec))
case "vpn":
log.Infof("[SecretAgent] Returning VPN secrets with %d fields for %s", len(sec), vpnSvc)
}
// If save=true, persist secrets in background after returning to NetworkManager
// This MUST happen after we return secrets, in a goroutine
if reply.Save {
go func() {
log.Infof("[SecretAgent] Persisting secrets with Update2: path=%s, setting=%s", path, settingName)
// Get existing connection settings
connObj := a.conn.Object("org.freedesktop.NetworkManager", path)
var existingSettings map[string]map[string]dbus.Variant
if err := connObj.Call("org.freedesktop.NetworkManager.Settings.Connection.GetSettings", 0).Store(&existingSettings); err != nil {
log.Warnf("[SecretAgent] GetSettings failed: %v", err)
return
}
// Build minimal settings with ONLY the section we're updating
// This avoids D-Bus type serialization issues with complex types like IPv6 addresses
settings := make(map[string]map[string]dbus.Variant)
// Copy connection section (required for Update2)
if connSection, ok := existingSettings["connection"]; ok {
settings["connection"] = connSection
}
// Update settings based on type
switch settingName {
case "vpn":
// Set password-flags=0 and add secrets to vpn section
vpn, ok := existingSettings["vpn"]
if !ok {
vpn = make(map[string]dbus.Variant)
}
// Get existing data map (vpn.data is string->string)
var data map[string]string
if dataVariant, ok := vpn["data"]; ok {
if dm, ok := dataVariant.Value().(map[string]string); ok {
data = make(map[string]string)
for k, v := range dm {
data[k] = v
}
} else {
data = make(map[string]string)
}
} else {
data = make(map[string]string)
}
// Update password-flags to 0 (system-stored)
data["password-flags"] = "0"
vpn["data"] = dbus.MakeVariant(data)
// Add secrets (vpn.secrets is string->string)
secs := make(map[string]string)
for k, v := range reply.Secrets {
secs[k] = v
}
vpn["secrets"] = dbus.MakeVariant(secs)
settings["vpn"] = vpn
log.Infof("[SecretAgent] Updated VPN settings: password-flags=0, secrets with %d fields", len(secs))
case "802-11-wireless-security":
// Set psk-flags=0 for WiFi
wifiSec, ok := existingSettings["802-11-wireless-security"]
if !ok {
wifiSec = make(map[string]dbus.Variant)
}
wifiSec["psk-flags"] = dbus.MakeVariant(uint32(0))
// Add PSK secret
if psk, ok := reply.Secrets["psk"]; ok {
wifiSec["psk"] = dbus.MakeVariant(psk)
log.Infof("[SecretAgent] Updated WiFi settings: psk-flags=0")
}
settings["802-11-wireless-security"] = wifiSec
case "802-1x":
// Set password-flags=0 for 802.1x
dot1x, ok := existingSettings["802-1x"]
if !ok {
dot1x = make(map[string]dbus.Variant)
}
dot1x["password-flags"] = dbus.MakeVariant(uint32(0))
// Add password secret
if password, ok := reply.Secrets["password"]; ok {
dot1x["password"] = dbus.MakeVariant(password)
log.Infof("[SecretAgent] Updated 802.1x settings: password-flags=0")
}
settings["802-1x"] = dot1x
}
// Call Update2 with correct signature:
// Update2(IN settings, IN flags, IN args) -> OUT result
// flags: 0x1 = to-disk
var result map[string]dbus.Variant
err := connObj.Call("org.freedesktop.NetworkManager.Settings.Connection.Update2", 0,
settings, uint32(0x1), map[string]dbus.Variant{}).Store(&result)
if err != nil {
log.Warnf("[SecretAgent] Update2(to-disk) failed: %v", err)
} else {
log.Infof("[SecretAgent] Successfully persisted secrets to disk for %s", settingName)
}
}()
}
return out, nil
}
func (a *SecretAgent) DeleteSecrets(conn map[string]nmVariantMap, path dbus.ObjectPath) *dbus.Error {
ssid := readSSID(conn)
log.Infof("[SecretAgent] DeleteSecrets called: path=%s, SSID=%s", path, ssid)
return nil
}
func (a *SecretAgent) DeleteSecrets2(path dbus.ObjectPath, setting string) *dbus.Error {
log.Infof("[SecretAgent] DeleteSecrets2 (alternate) called: path=%s, setting=%s", path, setting)
return nil
}
func (a *SecretAgent) CancelGetSecrets(path dbus.ObjectPath, settingName string) *dbus.Error {
log.Infof("[SecretAgent] CancelGetSecrets called: path=%s, setting=%s", path, settingName)
if a.prompts != nil {
if err := a.prompts.Cancel(string(path), settingName); err != nil {
log.Warnf("[SecretAgent] Failed to cancel prompt: %v", err)
}
}
return nil
}
func (a *SecretAgent) Introspect() (string, *dbus.Error) {
return introspectXML, nil
}
func readSSID(conn map[string]nmVariantMap) string {
if w, ok := conn["802-11-wireless"]; ok {
if v, ok := w["ssid"]; ok {
if b, ok := v.Value().([]byte); ok {
return string(b)
}
if s, ok := v.Value().(string); ok {
return s
}
}
}
return ""
}
func readConnTypeAndName(conn map[string]nmVariantMap) (string, string, string) {
var connType, name, svc string
if c, ok := conn["connection"]; ok {
if v, ok := c["type"]; ok {
if s, ok2 := v.Value().(string); ok2 {
connType = s
}
}
if v, ok := c["id"]; ok {
if s, ok2 := v.Value().(string); ok2 {
name = s
}
}
}
if vpn, ok := conn["vpn"]; ok {
if v, ok := vpn["service-type"]; ok {
if s, ok2 := v.Value().(string); ok2 {
svc = s
}
}
}
if name == "" && connType == "802-11-wireless" {
name = readSSID(conn)
}
return connType, name, svc
}
func fieldsNeeded(setting string, hints []string) []string {
switch setting {
case "802-11-wireless-security":
return []string{"psk"}
case "802-1x":
return []string{"identity", "password"}
case "vpn":
return hints
default:
return []string{}
}
}
func reasonFromFlags(flags uint32) string {
const (
NM_SECRET_AGENT_GET_SECRETS_FLAG_NONE = 0x0
NM_SECRET_AGENT_GET_SECRETS_FLAG_ALLOW_INTERACTION = 0x1
NM_SECRET_AGENT_GET_SECRETS_FLAG_REQUEST_NEW = 0x2
NM_SECRET_AGENT_GET_SECRETS_FLAG_USER_REQUESTED = 0x4
NM_SECRET_AGENT_GET_SECRETS_FLAG_WPS_PBC_ACTIVE = 0x8
NM_SECRET_AGENT_GET_SECRETS_FLAG_ONLY_SYSTEM = 0x80000000
NM_SECRET_AGENT_GET_SECRETS_FLAG_NO_ERRORS = 0x40000000
)
if flags&NM_SECRET_AGENT_GET_SECRETS_FLAG_REQUEST_NEW != 0 {
return "wrong-password"
}
if flags&NM_SECRET_AGENT_GET_SECRETS_FLAG_USER_REQUESTED != 0 {
return "user-requested"
}
return "required"
}

View File

@@ -0,0 +1,65 @@
package network
type Backend interface {
Initialize() error
Close()
GetWiFiEnabled() (bool, error)
SetWiFiEnabled(enabled bool) error
ScanWiFi() error
GetWiFiNetworkDetails(ssid string) (*NetworkInfoResponse, error)
ConnectWiFi(req ConnectionRequest) error
DisconnectWiFi() error
ForgetWiFiNetwork(ssid string) error
SetWiFiAutoconnect(ssid string, autoconnect bool) error
GetWiredConnections() ([]WiredConnection, error)
GetWiredNetworkDetails(uuid string) (*WiredNetworkInfoResponse, error)
ConnectEthernet() error
DisconnectEthernet() error
ActivateWiredConnection(uuid string) error
ListVPNProfiles() ([]VPNProfile, error)
ListActiveVPN() ([]VPNActive, error)
ConnectVPN(uuidOrName string, singleActive bool) error
DisconnectVPN(uuidOrName string) error
DisconnectAllVPN() error
ClearVPNCredentials(uuidOrName string) error
GetCurrentState() (*BackendState, error)
StartMonitoring(onStateChange func()) error
StopMonitoring()
GetPromptBroker() PromptBroker
SetPromptBroker(broker PromptBroker) error
SubmitCredentials(token string, secrets map[string]string, save bool) error
CancelCredentials(token string) error
}
type BackendState struct {
Backend string
NetworkStatus NetworkStatus
EthernetIP string
EthernetDevice string
EthernetConnected bool
EthernetConnectionUuid string
WiFiIP string
WiFiDevice string
WiFiConnected bool
WiFiEnabled bool
WiFiSSID string
WiFiBSSID string
WiFiSignal uint8
WiFiNetworks []WiFiNetwork
WiredConnections []WiredConnection
VPNProfiles []VPNProfile
VPNActive []VPNActive
IsConnecting bool
ConnectingSSID string
IsConnectingVPN bool
ConnectingVPNUUID string
LastError string
}

View File

@@ -0,0 +1,198 @@
package network
import (
"fmt"
"sync"
)
type HybridIwdNetworkdBackend struct {
wifi *IWDBackend
l3 *SystemdNetworkdBackend
onStateChange func()
stateMutex sync.RWMutex
}
func NewHybridIwdNetworkdBackend(w *IWDBackend, n *SystemdNetworkdBackend) (*HybridIwdNetworkdBackend, error) {
return &HybridIwdNetworkdBackend{
wifi: w,
l3: n,
}, nil
}
func (b *HybridIwdNetworkdBackend) Initialize() error {
if err := b.wifi.Initialize(); err != nil {
return fmt.Errorf("iwd init: %w", err)
}
if err := b.l3.Initialize(); err != nil {
return fmt.Errorf("networkd init: %w", err)
}
return nil
}
func (b *HybridIwdNetworkdBackend) Close() {
b.wifi.Close()
b.l3.Close()
}
func (b *HybridIwdNetworkdBackend) StartMonitoring(onStateChange func()) error {
b.onStateChange = onStateChange
mergedCallback := func() {
ws, _ := b.wifi.GetCurrentState()
ls, _ := b.l3.GetCurrentState()
if ws != nil && ls != nil && ws.WiFiDevice != "" && ls.WiFiIP != "" {
b.wifi.MarkIPConfigSeen()
}
if b.onStateChange != nil {
b.onStateChange()
}
}
if err := b.wifi.StartMonitoring(mergedCallback); err != nil {
return fmt.Errorf("wifi monitoring: %w", err)
}
if err := b.l3.StartMonitoring(mergedCallback); err != nil {
return fmt.Errorf("l3 monitoring: %w", err)
}
return nil
}
func (b *HybridIwdNetworkdBackend) StopMonitoring() {
b.wifi.StopMonitoring()
b.l3.StopMonitoring()
}
func (b *HybridIwdNetworkdBackend) GetCurrentState() (*BackendState, error) {
ws, err := b.wifi.GetCurrentState()
if err != nil {
return nil, err
}
ls, err := b.l3.GetCurrentState()
if err != nil {
return nil, err
}
merged := *ws
merged.Backend = "iwd+networkd"
merged.WiFiIP = ls.WiFiIP
merged.EthernetConnected = ls.EthernetConnected
merged.EthernetIP = ls.EthernetIP
merged.EthernetDevice = ls.EthernetDevice
merged.EthernetConnectionUuid = ls.EthernetConnectionUuid
merged.WiredConnections = ls.WiredConnections
if ls.EthernetConnected && ls.EthernetIP != "" {
merged.NetworkStatus = StatusEthernet
} else if ws.WiFiConnected && ls.WiFiIP != "" {
merged.NetworkStatus = StatusWiFi
} else {
merged.NetworkStatus = StatusDisconnected
}
return &merged, nil
}
func (b *HybridIwdNetworkdBackend) GetWiFiEnabled() (bool, error) {
return b.wifi.GetWiFiEnabled()
}
func (b *HybridIwdNetworkdBackend) SetWiFiEnabled(enabled bool) error {
return b.wifi.SetWiFiEnabled(enabled)
}
func (b *HybridIwdNetworkdBackend) ScanWiFi() error {
return b.wifi.ScanWiFi()
}
func (b *HybridIwdNetworkdBackend) GetWiFiNetworkDetails(ssid string) (*NetworkInfoResponse, error) {
return b.wifi.GetWiFiNetworkDetails(ssid)
}
func (b *HybridIwdNetworkdBackend) ConnectWiFi(req ConnectionRequest) error {
if err := b.wifi.ConnectWiFi(req); err != nil {
return err
}
ws, err := b.wifi.GetCurrentState()
if err == nil && ws.WiFiDevice != "" {
b.l3.EnsureDhcpUp(ws.WiFiDevice)
}
return nil
}
func (b *HybridIwdNetworkdBackend) DisconnectWiFi() error {
return b.wifi.DisconnectWiFi()
}
func (b *HybridIwdNetworkdBackend) ForgetWiFiNetwork(ssid string) error {
return b.wifi.ForgetWiFiNetwork(ssid)
}
func (b *HybridIwdNetworkdBackend) GetWiredConnections() ([]WiredConnection, error) {
return b.l3.GetWiredConnections()
}
func (b *HybridIwdNetworkdBackend) GetWiredNetworkDetails(uuid string) (*WiredNetworkInfoResponse, error) {
return b.l3.GetWiredNetworkDetails(uuid)
}
func (b *HybridIwdNetworkdBackend) ConnectEthernet() error {
return b.l3.ConnectEthernet()
}
func (b *HybridIwdNetworkdBackend) DisconnectEthernet() error {
return b.l3.DisconnectEthernet()
}
func (b *HybridIwdNetworkdBackend) ActivateWiredConnection(uuid string) error {
return b.l3.ActivateWiredConnection(uuid)
}
func (b *HybridIwdNetworkdBackend) ListVPNProfiles() ([]VPNProfile, error) {
return []VPNProfile{}, nil
}
func (b *HybridIwdNetworkdBackend) ListActiveVPN() ([]VPNActive, error) {
return []VPNActive{}, nil
}
func (b *HybridIwdNetworkdBackend) ConnectVPN(uuidOrName string, singleActive bool) error {
return fmt.Errorf("VPN not supported in hybrid mode")
}
func (b *HybridIwdNetworkdBackend) DisconnectVPN(uuidOrName string) error {
return fmt.Errorf("VPN not supported in hybrid mode")
}
func (b *HybridIwdNetworkdBackend) DisconnectAllVPN() error {
return fmt.Errorf("VPN not supported in hybrid mode")
}
func (b *HybridIwdNetworkdBackend) ClearVPNCredentials(uuidOrName string) error {
return fmt.Errorf("VPN not supported in hybrid mode")
}
func (b *HybridIwdNetworkdBackend) GetPromptBroker() PromptBroker {
return b.wifi.GetPromptBroker()
}
func (b *HybridIwdNetworkdBackend) SetPromptBroker(broker PromptBroker) error {
return b.wifi.SetPromptBroker(broker)
}
func (b *HybridIwdNetworkdBackend) SubmitCredentials(token string, secrets map[string]string, save bool) error {
return b.wifi.SubmitCredentials(token, secrets, save)
}
func (b *HybridIwdNetworkdBackend) CancelCredentials(token string) error {
return b.wifi.CancelCredentials(token)
}
func (b *HybridIwdNetworkdBackend) SetWiFiAutoconnect(ssid string, autoconnect bool) error {
return b.wifi.SetWiFiAutoconnect(ssid, autoconnect)
}

View File

@@ -0,0 +1,135 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestHybridIwdNetworkdBackend_New(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, err := NewHybridIwdNetworkdBackend(wifi, l3)
assert.NoError(t, err)
assert.NotNil(t, hybrid)
assert.NotNil(t, hybrid.wifi)
assert.NotNil(t, hybrid.l3)
}
func TestHybridIwdNetworkdBackend_GetCurrentState_MergesState(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
wifi.state.WiFiConnected = true
wifi.state.WiFiSSID = "TestNetwork"
wifi.state.WiFiBSSID = "00:11:22:33:44:55"
wifi.state.WiFiSignal = 75
wifi.state.WiFiDevice = "wlan0"
l3.state.WiFiIP = "192.168.1.100"
l3.state.EthernetConnected = false
state, err := hybrid.GetCurrentState()
assert.NoError(t, err)
assert.NotNil(t, state)
assert.Equal(t, "iwd+networkd", state.Backend)
assert.Equal(t, "TestNetwork", state.WiFiSSID)
assert.Equal(t, "00:11:22:33:44:55", state.WiFiBSSID)
assert.Equal(t, uint8(75), state.WiFiSignal)
assert.Equal(t, "192.168.1.100", state.WiFiIP)
assert.True(t, state.WiFiConnected)
assert.False(t, state.EthernetConnected)
assert.Equal(t, StatusWiFi, state.NetworkStatus)
}
func TestHybridIwdNetworkdBackend_GetCurrentState_EthernetPriority(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
wifi.state.WiFiConnected = true
wifi.state.WiFiSSID = "TestNetwork"
l3.state.WiFiIP = "192.168.1.100"
l3.state.EthernetConnected = true
l3.state.EthernetIP = "192.168.1.50"
l3.state.EthernetDevice = "eth0"
state, err := hybrid.GetCurrentState()
assert.NoError(t, err)
assert.Equal(t, StatusEthernet, state.NetworkStatus)
assert.Equal(t, "192.168.1.50", state.EthernetIP)
assert.Equal(t, "eth0", state.EthernetDevice)
}
func TestHybridIwdNetworkdBackend_GetCurrentState_WiFiNoIP(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
wifi.state.WiFiConnected = true
wifi.state.WiFiSSID = "TestNetwork"
l3.state.WiFiIP = ""
l3.state.EthernetConnected = false
state, err := hybrid.GetCurrentState()
assert.NoError(t, err)
assert.Equal(t, StatusDisconnected, state.NetworkStatus)
assert.True(t, state.WiFiConnected)
assert.Empty(t, state.WiFiIP)
}
func TestHybridIwdNetworkdBackend_WiFiDelegation(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
enabled, err := hybrid.GetWiFiEnabled()
assert.NoError(t, err)
assert.True(t, enabled)
state, err := hybrid.GetCurrentState()
assert.NoError(t, err)
assert.NotNil(t, state)
assert.Equal(t, "iwd+networkd", state.Backend)
}
func TestHybridIwdNetworkdBackend_WiredDelegation(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
conns, err := hybrid.GetWiredConnections()
assert.NoError(t, err)
assert.Empty(t, conns)
}
func TestHybridIwdNetworkdBackend_VPNNotSupported(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
profiles, err := hybrid.ListVPNProfiles()
assert.NoError(t, err)
assert.Empty(t, profiles)
active, err := hybrid.ListActiveVPN()
assert.NoError(t, err)
assert.Empty(t, active)
err = hybrid.ConnectVPN("test", false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
}
func TestHybridIwdNetworkdBackend_PromptBrokerDelegation(t *testing.T) {
wifi, _ := NewIWDBackend()
l3, _ := NewSystemdNetworkdBackend()
hybrid, _ := NewHybridIwdNetworkdBackend(wifi, l3)
broker := hybrid.GetPromptBroker()
assert.Nil(t, broker)
}

View File

@@ -0,0 +1,232 @@
package network
import (
"fmt"
"sync"
"time"
"github.com/godbus/dbus/v5"
)
const (
iwdBusName = "net.connman.iwd"
iwdObjectPath = "/"
iwdAdapterInterface = "net.connman.iwd.Adapter"
iwdDeviceInterface = "net.connman.iwd.Device"
iwdStationInterface = "net.connman.iwd.Station"
iwdNetworkInterface = "net.connman.iwd.Network"
iwdKnownNetworkInterface = "net.connman.iwd.KnownNetwork"
dbusObjectManager = "org.freedesktop.DBus.ObjectManager"
dbusPropertiesInterface = "org.freedesktop.DBus.Properties"
)
type connectAttempt struct {
ssid string
netPath dbus.ObjectPath
start time.Time
deadline time.Time
sawAuthish bool
connectedAt time.Time
sawIPConfig bool
sawPromptRetry bool
finalized bool
mu sync.Mutex
}
type IWDBackend struct {
conn *dbus.Conn
state *BackendState
stateMutex sync.RWMutex
promptBroker PromptBroker
onStateChange func()
devicePath dbus.ObjectPath
stationPath dbus.ObjectPath
adapterPath dbus.ObjectPath
iwdAgent *IWDAgent
stopChan chan struct{}
sigWG sync.WaitGroup
curAttempt *connectAttempt
attemptMutex sync.RWMutex
recentScans map[string]time.Time
recentScansMu sync.Mutex
}
func NewIWDBackend() (*IWDBackend, error) {
backend := &IWDBackend{
state: &BackendState{
Backend: "iwd",
WiFiEnabled: true,
},
stopChan: make(chan struct{}),
recentScans: make(map[string]time.Time),
}
return backend, nil
}
func (b *IWDBackend) Initialize() error {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return fmt.Errorf("failed to connect to system bus: %w", err)
}
b.conn = conn
if err := b.discoverDevices(); err != nil {
conn.Close()
return fmt.Errorf("failed to discover iwd devices: %w", err)
}
if err := b.updateState(); err != nil {
conn.Close()
return fmt.Errorf("failed to get initial state: %w", err)
}
return nil
}
func (b *IWDBackend) Close() {
close(b.stopChan)
b.sigWG.Wait()
if b.iwdAgent != nil {
b.iwdAgent.Close()
}
if b.conn != nil {
b.conn.Close()
}
}
func (b *IWDBackend) discoverDevices() error {
obj := b.conn.Object(iwdBusName, iwdObjectPath)
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
err := obj.Call(dbusObjectManager+".GetManagedObjects", 0).Store(&objects)
if err != nil {
return fmt.Errorf("failed to get managed objects: %w", err)
}
for path, interfaces := range objects {
if _, hasStation := interfaces[iwdStationInterface]; hasStation {
b.stationPath = path
}
if _, hasDevice := interfaces[iwdDeviceInterface]; hasDevice {
b.devicePath = path
if devProps, ok := interfaces[iwdDeviceInterface]; ok {
if nameVar, ok := devProps["Name"]; ok {
if name, ok := nameVar.Value().(string); ok {
b.stateMutex.Lock()
b.state.WiFiDevice = name
b.stateMutex.Unlock()
}
}
}
}
if _, hasAdapter := interfaces[iwdAdapterInterface]; hasAdapter {
b.adapterPath = path
}
}
if b.stationPath == "" || b.devicePath == "" {
return fmt.Errorf("no WiFi device found")
}
return nil
}
func (b *IWDBackend) GetCurrentState() (*BackendState, error) {
state := *b.state
state.WiFiNetworks = append([]WiFiNetwork(nil), b.state.WiFiNetworks...)
state.WiredConnections = append([]WiredConnection(nil), b.state.WiredConnections...)
return &state, nil
}
func (b *IWDBackend) OnUserCanceledPrompt() {
b.stateMutex.RLock()
cancelledSSID := b.state.ConnectingSSID
b.stateMutex.RUnlock()
b.setConnectError("user-canceled")
if cancelledSSID != "" {
if err := b.ForgetWiFiNetwork(cancelledSSID); err != nil {
}
}
if b.onStateChange != nil {
b.onStateChange()
}
}
func (b *IWDBackend) OnPromptRetry(ssid string) {
b.attemptMutex.RLock()
att := b.curAttempt
b.attemptMutex.RUnlock()
if att != nil && att.ssid == ssid {
att.mu.Lock()
att.sawPromptRetry = true
att.mu.Unlock()
}
}
func (b *IWDBackend) MarkIPConfigSeen() {
b.attemptMutex.RLock()
att := b.curAttempt
b.attemptMutex.RUnlock()
if att != nil {
att.mu.Lock()
att.sawIPConfig = true
att.mu.Unlock()
}
}
func (b *IWDBackend) GetPromptBroker() PromptBroker {
return b.promptBroker
}
func (b *IWDBackend) SetPromptBroker(broker PromptBroker) error {
if broker == nil {
return fmt.Errorf("broker cannot be nil")
}
b.promptBroker = broker
return nil
}
func (b *IWDBackend) SubmitCredentials(token string, secrets map[string]string, save bool) error {
if b.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
return b.promptBroker.Resolve(token, PromptReply{
Secrets: secrets,
Save: save,
Cancel: false,
})
}
func (b *IWDBackend) CancelCredentials(token string) error {
if b.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
return b.promptBroker.Resolve(token, PromptReply{
Cancel: true,
})
}
func (b *IWDBackend) StopMonitoring() {
select {
case <-b.stopChan:
return
default:
close(b.stopChan)
}
b.sigWG.Wait()
}

View File

@@ -0,0 +1,355 @@
package network
import (
"fmt"
"time"
"github.com/godbus/dbus/v5"
)
func (b *IWDBackend) StartMonitoring(onStateChange func()) error {
b.onStateChange = onStateChange
if b.promptBroker != nil {
agent, err := NewIWDAgent(b.conn, b.promptBroker)
if err != nil {
return fmt.Errorf("failed to start IWD agent: %w", err)
}
agent.onUserCanceled = b.OnUserCanceledPrompt
agent.onPromptRetry = b.OnPromptRetry
b.iwdAgent = agent
}
sigChan := make(chan *dbus.Signal, 100)
b.conn.Signal(sigChan)
if b.devicePath != "" {
err := b.conn.AddMatchSignal(
dbus.WithMatchObjectPath(b.devicePath),
dbus.WithMatchInterface(dbusPropertiesInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
if err != nil {
return fmt.Errorf("failed to add device signal match: %w", err)
}
}
if b.stationPath != "" {
err := b.conn.AddMatchSignal(
dbus.WithMatchObjectPath(b.stationPath),
dbus.WithMatchInterface(dbusPropertiesInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
if err != nil {
return fmt.Errorf("failed to add station signal match: %w", err)
}
}
b.sigWG.Add(1)
go b.signalHandler(sigChan)
return nil
}
func (b *IWDBackend) signalHandler(sigChan chan *dbus.Signal) {
defer b.sigWG.Done()
for {
select {
case <-b.stopChan:
b.conn.RemoveSignal(sigChan)
close(sigChan)
return
case sig := <-sigChan:
if sig == nil {
return
}
if sig.Name != dbusPropertiesInterface+".PropertiesChanged" {
continue
}
if len(sig.Body) < 2 {
continue
}
iface, ok := sig.Body[0].(string)
if !ok {
continue
}
changed, ok := sig.Body[1].(map[string]dbus.Variant)
if !ok {
continue
}
stateChanged := false
switch iface {
case iwdDeviceInterface:
if sig.Path == b.devicePath {
if poweredVar, ok := changed["Powered"]; ok {
if powered, ok := poweredVar.Value().(bool); ok {
b.stateMutex.Lock()
if b.state.WiFiEnabled != powered {
b.state.WiFiEnabled = powered
stateChanged = true
}
b.stateMutex.Unlock()
}
}
}
case iwdStationInterface:
if sig.Path == b.stationPath {
if scanningVar, ok := changed["Scanning"]; ok {
if scanning, ok := scanningVar.Value().(bool); ok && !scanning {
networks, err := b.updateWiFiNetworks()
if err == nil {
b.stateMutex.Lock()
b.state.WiFiNetworks = networks
b.stateMutex.Unlock()
stateChanged = true
}
b.stateMutex.RLock()
wifiConnected := b.state.WiFiConnected
b.stateMutex.RUnlock()
if wifiConnected {
stationObj := b.conn.Object(iwdBusName, b.stationPath)
connNetVar, err := stationObj.GetProperty(iwdStationInterface + ".ConnectedNetwork")
if err == nil && connNetVar.Value() != nil {
if netPath, ok := connNetVar.Value().(dbus.ObjectPath); ok && netPath != "/" {
var orderedNetworks [][]dbus.Variant
err = stationObj.Call(iwdStationInterface+".GetOrderedNetworks", 0).Store(&orderedNetworks)
if err == nil {
for _, netData := range orderedNetworks {
if len(netData) < 2 {
continue
}
currentNetPath, ok := netData[0].Value().(dbus.ObjectPath)
if !ok || currentNetPath != netPath {
continue
}
signalStrength, ok := netData[1].Value().(int16)
if !ok {
continue
}
signalDbm := signalStrength / 100
signal := uint8(signalDbm + 100)
if signalDbm > 0 {
signal = 100
} else if signalDbm < -100 {
signal = 0
}
b.stateMutex.Lock()
if b.state.WiFiSignal != signal {
b.state.WiFiSignal = signal
stateChanged = true
}
b.stateMutex.Unlock()
break
}
}
}
}
}
}
}
if stateVar, ok := changed["State"]; ok {
if state, ok := stateVar.Value().(string); ok {
b.attemptMutex.RLock()
att := b.curAttempt
b.attemptMutex.RUnlock()
var connPath dbus.ObjectPath
if v, ok := changed["ConnectedNetwork"]; ok {
if v.Value() != nil {
if p, ok := v.Value().(dbus.ObjectPath); ok {
connPath = p
}
}
}
if connPath == "" {
station := b.conn.Object(iwdBusName, b.stationPath)
if cnVar, err := station.GetProperty(iwdStationInterface + ".ConnectedNetwork"); err == nil && cnVar.Value() != nil {
cnVar.Store(&connPath)
}
}
b.stateMutex.RLock()
prevConnected := b.state.WiFiConnected
prevSSID := b.state.WiFiSSID
b.stateMutex.RUnlock()
targetPath := dbus.ObjectPath("")
if att != nil {
targetPath = att.netPath
}
isTarget := att != nil && targetPath != "" && connPath == targetPath
if att != nil {
switch state {
case "authenticating", "associating", "associated", "roaming":
att.mu.Lock()
att.sawAuthish = true
att.mu.Unlock()
}
}
if att != nil && state == "connected" && isTarget {
att.mu.Lock()
if att.connectedAt.IsZero() {
att.connectedAt = time.Now()
}
att.mu.Unlock()
}
if att != nil && state == "configuring" {
att.mu.Lock()
att.sawIPConfig = true
att.mu.Unlock()
}
switch state {
case "connected":
b.stateMutex.Lock()
b.state.WiFiConnected = true
b.state.NetworkStatus = StatusWiFi
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = ""
b.stateMutex.Unlock()
if connPath != "" && connPath != "/" {
netObj := b.conn.Object(iwdBusName, connPath)
if nameVar, err := netObj.GetProperty(iwdNetworkInterface + ".Name"); err == nil {
if name, ok := nameVar.Value().(string); ok {
b.stateMutex.Lock()
b.state.WiFiSSID = name
b.stateMutex.Unlock()
}
}
}
stateChanged = true
if att != nil && isTarget {
go func(attLocal *connectAttempt, tgt dbus.ObjectPath) {
time.Sleep(3 * time.Second)
station := b.conn.Object(iwdBusName, b.stationPath)
var nowState string
if stVar, err := station.GetProperty(iwdStationInterface + ".State"); err == nil {
stVar.Store(&nowState)
}
var nowConn dbus.ObjectPath
if cnVar, err := station.GetProperty(iwdStationInterface + ".ConnectedNetwork"); err == nil && cnVar.Value() != nil {
cnVar.Store(&nowConn)
}
if nowState == "connected" && nowConn == tgt {
b.finalizeAttempt(attLocal, "")
b.attemptMutex.Lock()
if b.curAttempt == attLocal {
b.curAttempt = nil
}
b.attemptMutex.Unlock()
}
}(att, targetPath)
}
case "disconnecting", "disconnected":
if att != nil {
wasConnectedToTarget := prevConnected && prevSSID == att.ssid
if wasConnectedToTarget || isTarget {
code := b.classifyAttempt(att)
b.finalizeAttempt(att, code)
b.attemptMutex.Lock()
if b.curAttempt == att {
b.curAttempt = nil
}
b.attemptMutex.Unlock()
}
}
b.stateMutex.Lock()
b.state.WiFiConnected = false
if state == "disconnected" {
b.state.NetworkStatus = StatusDisconnected
}
b.stateMutex.Unlock()
stateChanged = true
}
}
}
if connNetVar, ok := changed["ConnectedNetwork"]; ok {
if netPath, ok := connNetVar.Value().(dbus.ObjectPath); ok && netPath != "/" {
netObj := b.conn.Object(iwdBusName, netPath)
nameVar, err := netObj.GetProperty(iwdNetworkInterface + ".Name")
if err == nil {
if name, ok := nameVar.Value().(string); ok {
b.stateMutex.Lock()
if b.state.WiFiSSID != name {
b.state.WiFiSSID = name
stateChanged = true
}
b.stateMutex.Unlock()
}
}
stationObj := b.conn.Object(iwdBusName, b.stationPath)
var orderedNetworks [][]dbus.Variant
err = stationObj.Call(iwdStationInterface+".GetOrderedNetworks", 0).Store(&orderedNetworks)
if err == nil {
for _, netData := range orderedNetworks {
if len(netData) < 2 {
continue
}
currentNetPath, ok := netData[0].Value().(dbus.ObjectPath)
if !ok || currentNetPath != netPath {
continue
}
signalStrength, ok := netData[1].Value().(int16)
if !ok {
continue
}
signalDbm := signalStrength / 100
signal := uint8(signalDbm + 100)
if signalDbm > 0 {
signal = 100
} else if signalDbm < -100 {
signal = 0
}
b.stateMutex.Lock()
if b.state.WiFiSignal != signal {
b.state.WiFiSignal = signal
stateChanged = true
}
b.stateMutex.Unlock()
break
}
}
} else {
b.stateMutex.Lock()
if b.state.WiFiSSID != "" {
b.state.WiFiSSID = ""
b.state.WiFiSignal = 0
stateChanged = true
}
b.stateMutex.Unlock()
}
}
}
}
if stateChanged && b.onStateChange != nil {
b.onStateChange()
}
}
}
}

View File

@@ -0,0 +1,212 @@
package network
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestIWDBackend_MarkIPConfigSeen(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/net/connman/iwd/0/1/test",
start: time.Now(),
deadline: time.Now().Add(15 * time.Second),
}
backend.attemptMutex.Lock()
backend.curAttempt = att
backend.attemptMutex.Unlock()
backend.MarkIPConfigSeen()
att.mu.Lock()
assert.True(t, att.sawIPConfig, "sawIPConfig should be true after MarkIPConfigSeen")
att.mu.Unlock()
}
func TestIWDBackend_MarkIPConfigSeen_NoAttempt(t *testing.T) {
backend, _ := NewIWDBackend()
backend.attemptMutex.Lock()
backend.curAttempt = nil
backend.attemptMutex.Unlock()
backend.MarkIPConfigSeen()
}
func TestIWDBackend_OnPromptRetry(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/net/connman/iwd/0/1/test",
start: time.Now(),
deadline: time.Now().Add(15 * time.Second),
}
backend.attemptMutex.Lock()
backend.curAttempt = att
backend.attemptMutex.Unlock()
backend.OnPromptRetry("TestNetwork")
att.mu.Lock()
assert.True(t, att.sawPromptRetry, "sawPromptRetry should be true after OnPromptRetry")
att.mu.Unlock()
}
func TestIWDBackend_OnPromptRetry_WrongSSID(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/net/connman/iwd/0/1/test",
start: time.Now(),
deadline: time.Now().Add(15 * time.Second),
}
backend.attemptMutex.Lock()
backend.curAttempt = att
backend.attemptMutex.Unlock()
backend.OnPromptRetry("DifferentNetwork")
att.mu.Lock()
assert.False(t, att.sawPromptRetry, "sawPromptRetry should remain false for different SSID")
att.mu.Unlock()
}
func TestIWDBackend_ClassifyAttempt_BadCredentials_PromptRetry(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/test",
start: time.Now().Add(-5 * time.Second),
deadline: time.Now().Add(10 * time.Second),
sawPromptRetry: true,
}
code := backend.classifyAttempt(att)
assert.Equal(t, "bad-credentials", code)
}
func TestIWDBackend_ClassifyAttempt_DhcpTimeout(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/test",
start: time.Now().Add(-13 * time.Second),
deadline: time.Now().Add(2 * time.Second),
sawAuthish: true,
sawIPConfig: false,
}
code := backend.classifyAttempt(att)
assert.Equal(t, "dhcp-timeout", code)
}
func TestIWDBackend_ClassifyAttempt_AssocTimeout(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/test",
start: time.Now().Add(-5 * time.Second),
deadline: time.Now().Add(10 * time.Second),
}
backend.recentScansMu.Lock()
backend.recentScans["TestNetwork"] = time.Now()
backend.recentScansMu.Unlock()
code := backend.classifyAttempt(att)
assert.Equal(t, "assoc-timeout", code)
}
func TestIWDBackend_ClassifyAttempt_NoSuchSSID(t *testing.T) {
backend, _ := NewIWDBackend()
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/test",
start: time.Now().Add(-5 * time.Second),
deadline: time.Now().Add(10 * time.Second),
}
code := backend.classifyAttempt(att)
assert.Equal(t, "no-such-ssid", code)
}
func TestIWDBackend_MapIwdDBusError(t *testing.T) {
backend, _ := NewIWDBackend()
testCases := []struct {
name string
expected string
}{
{"net.connman.iwd.Error.AlreadyConnected", "already-connected"},
{"net.connman.iwd.Error.AuthenticationFailed", "bad-credentials"},
{"net.connman.iwd.Error.InvalidKey", "bad-credentials"},
{"net.connman.iwd.Error.IncorrectPassphrase", "bad-credentials"},
{"net.connman.iwd.Error.NotFound", "no-such-ssid"},
{"net.connman.iwd.Error.NotSupported", "connection-failed"},
{"net.connman.iwd.Agent.Error.Canceled", "user-canceled"},
{"net.connman.iwd.Error.Unknown", "connection-failed"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
code := backend.mapIwdDBusError(tc.name)
assert.Equal(t, tc.expected, code)
})
}
}
func TestConnectAttempt_Finalization(t *testing.T) {
backend, _ := NewIWDBackend()
backend.state = &BackendState{}
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/test",
start: time.Now(),
deadline: time.Now().Add(15 * time.Second),
}
backend.finalizeAttempt(att, "bad-credentials")
att.mu.Lock()
assert.True(t, att.finalized)
att.mu.Unlock()
backend.stateMutex.RLock()
assert.False(t, backend.state.IsConnecting)
assert.Empty(t, backend.state.ConnectingSSID)
assert.Equal(t, "bad-credentials", backend.state.LastError)
backend.stateMutex.RUnlock()
}
func TestConnectAttempt_DoubleFinalization(t *testing.T) {
backend, _ := NewIWDBackend()
backend.state = &BackendState{}
att := &connectAttempt{
ssid: "TestNetwork",
netPath: "/test",
start: time.Now(),
deadline: time.Now().Add(15 * time.Second),
}
backend.finalizeAttempt(att, "bad-credentials")
backend.finalizeAttempt(att, "dhcp-timeout")
backend.stateMutex.RLock()
assert.Equal(t, "bad-credentials", backend.state.LastError)
backend.stateMutex.RUnlock()
}

View File

@@ -0,0 +1,47 @@
package network
import "fmt"
func (b *IWDBackend) GetWiredConnections() ([]WiredConnection, error) {
return nil, fmt.Errorf("wired connections not supported by iwd")
}
func (b *IWDBackend) GetWiredNetworkDetails(uuid string) (*WiredNetworkInfoResponse, error) {
return nil, fmt.Errorf("wired connections not supported by iwd")
}
func (b *IWDBackend) ConnectEthernet() error {
return fmt.Errorf("wired connections not supported by iwd")
}
func (b *IWDBackend) DisconnectEthernet() error {
return fmt.Errorf("wired connections not supported by iwd")
}
func (b *IWDBackend) ActivateWiredConnection(uuid string) error {
return fmt.Errorf("wired connections not supported by iwd")
}
func (b *IWDBackend) ListVPNProfiles() ([]VPNProfile, error) {
return nil, fmt.Errorf("VPN not supported by iwd backend")
}
func (b *IWDBackend) ListActiveVPN() ([]VPNActive, error) {
return nil, fmt.Errorf("VPN not supported by iwd backend")
}
func (b *IWDBackend) ConnectVPN(uuidOrName string, singleActive bool) error {
return fmt.Errorf("VPN not supported by iwd backend")
}
func (b *IWDBackend) DisconnectVPN(uuidOrName string) error {
return fmt.Errorf("VPN not supported by iwd backend")
}
func (b *IWDBackend) DisconnectAllVPN() error {
return fmt.Errorf("VPN not supported by iwd backend")
}
func (b *IWDBackend) ClearVPNCredentials(uuidOrName string) error {
return fmt.Errorf("VPN not supported by iwd backend")
}

View File

@@ -0,0 +1,662 @@
package network
import (
"fmt"
"time"
"github.com/AvengeMedia/danklinux/internal/errdefs"
"github.com/godbus/dbus/v5"
)
func (b *IWDBackend) updateState() error {
if b.devicePath == "" {
return nil
}
obj := b.conn.Object(iwdBusName, b.devicePath)
poweredVar, err := obj.GetProperty(iwdDeviceInterface + ".Powered")
if err == nil {
if powered, ok := poweredVar.Value().(bool); ok {
b.stateMutex.Lock()
b.state.WiFiEnabled = powered
b.stateMutex.Unlock()
}
}
if b.stationPath == "" {
return nil
}
stationObj := b.conn.Object(iwdBusName, b.stationPath)
stateVar, err := stationObj.GetProperty(iwdStationInterface + ".State")
if err == nil {
if state, ok := stateVar.Value().(string); ok {
b.stateMutex.Lock()
b.state.WiFiConnected = (state == "connected")
if state == "connected" {
b.state.NetworkStatus = StatusWiFi
} else {
b.state.NetworkStatus = StatusDisconnected
}
b.stateMutex.Unlock()
}
}
connNetVar, err := stationObj.GetProperty(iwdStationInterface + ".ConnectedNetwork")
if err == nil && connNetVar.Value() != nil {
if netPath, ok := connNetVar.Value().(dbus.ObjectPath); ok && netPath != "/" {
netObj := b.conn.Object(iwdBusName, netPath)
nameVar, err := netObj.GetProperty(iwdNetworkInterface + ".Name")
if err == nil {
if name, ok := nameVar.Value().(string); ok {
b.stateMutex.Lock()
b.state.WiFiSSID = name
b.stateMutex.Unlock()
}
}
var orderedNetworks [][]dbus.Variant
err = stationObj.Call(iwdStationInterface+".GetOrderedNetworks", 0).Store(&orderedNetworks)
if err == nil {
for _, netData := range orderedNetworks {
if len(netData) < 2 {
continue
}
currentNetPath, ok := netData[0].Value().(dbus.ObjectPath)
if !ok || currentNetPath != netPath {
continue
}
signalStrength, ok := netData[1].Value().(int16)
if !ok {
continue
}
signalDbm := signalStrength / 100
signal := uint8(signalDbm + 100)
if signalDbm > 0 {
signal = 100
} else if signalDbm < -100 {
signal = 0
}
b.stateMutex.Lock()
b.state.WiFiSignal = signal
b.stateMutex.Unlock()
break
}
}
}
}
networks, err := b.updateWiFiNetworks()
if err == nil {
b.stateMutex.Lock()
b.state.WiFiNetworks = networks
b.stateMutex.Unlock()
}
return nil
}
func (b *IWDBackend) GetWiFiEnabled() (bool, error) {
b.stateMutex.RLock()
defer b.stateMutex.RUnlock()
return b.state.WiFiEnabled, nil
}
func (b *IWDBackend) SetWiFiEnabled(enabled bool) error {
if b.devicePath == "" {
return fmt.Errorf("no WiFi device available")
}
obj := b.conn.Object(iwdBusName, b.devicePath)
call := obj.Call(dbusPropertiesInterface+".Set", 0, iwdDeviceInterface, "Powered", dbus.MakeVariant(enabled))
if call.Err != nil {
return fmt.Errorf("failed to set WiFi enabled: %w", call.Err)
}
b.stateMutex.Lock()
b.state.WiFiEnabled = enabled
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *IWDBackend) ScanWiFi() error {
if b.stationPath == "" {
return fmt.Errorf("no WiFi device available")
}
obj := b.conn.Object(iwdBusName, b.stationPath)
scanningVar, err := obj.GetProperty(iwdStationInterface + ".Scanning")
if err != nil {
return fmt.Errorf("failed to check scanning state: %w", err)
}
if scanning, ok := scanningVar.Value().(bool); ok && scanning {
return fmt.Errorf("scan already in progress")
}
call := obj.Call(iwdStationInterface+".Scan", 0)
if call.Err != nil {
return fmt.Errorf("scan request failed: %w", call.Err)
}
return nil
}
func (b *IWDBackend) updateWiFiNetworks() ([]WiFiNetwork, error) {
if b.stationPath == "" {
return nil, fmt.Errorf("no WiFi device available")
}
obj := b.conn.Object(iwdBusName, b.stationPath)
var orderedNetworks [][]dbus.Variant
err := obj.Call(iwdStationInterface+".GetOrderedNetworks", 0).Store(&orderedNetworks)
if err != nil {
return nil, fmt.Errorf("failed to get networks: %w", err)
}
knownNetworks, err := b.getKnownNetworks()
if err != nil {
knownNetworks = make(map[string]bool)
}
autoconnectMap, err := b.getAutoconnectSettings()
if err != nil {
autoconnectMap = make(map[string]bool)
}
b.stateMutex.RLock()
currentSSID := b.state.WiFiSSID
wifiConnected := b.state.WiFiConnected
b.stateMutex.RUnlock()
networks := make([]WiFiNetwork, 0, len(orderedNetworks))
for _, netData := range orderedNetworks {
if len(netData) < 2 {
continue
}
networkPath, ok := netData[0].Value().(dbus.ObjectPath)
if !ok {
continue
}
signalStrength, ok := netData[1].Value().(int16)
if !ok {
continue
}
netObj := b.conn.Object(iwdBusName, networkPath)
nameVar, err := netObj.GetProperty(iwdNetworkInterface + ".Name")
if err != nil {
continue
}
name, ok := nameVar.Value().(string)
if !ok {
continue
}
typeVar, err := netObj.GetProperty(iwdNetworkInterface + ".Type")
if err != nil {
continue
}
netType, ok := typeVar.Value().(string)
if !ok {
continue
}
signalDbm := signalStrength / 100
signal := uint8(signalDbm + 100)
if signalDbm > 0 {
signal = 100
} else if signalDbm < -100 {
signal = 0
}
secured := netType != "open"
network := WiFiNetwork{
SSID: name,
Signal: signal,
Secured: secured,
Connected: wifiConnected && name == currentSSID,
Saved: knownNetworks[name],
Autoconnect: autoconnectMap[name],
Enterprise: netType == "8021x",
}
networks = append(networks, network)
}
sortWiFiNetworks(networks)
b.stateMutex.Lock()
b.state.WiFiNetworks = networks
b.stateMutex.Unlock()
now := time.Now()
b.recentScansMu.Lock()
for _, net := range networks {
b.recentScans[net.SSID] = now
}
b.recentScansMu.Unlock()
return networks, nil
}
func (b *IWDBackend) getKnownNetworks() (map[string]bool, error) {
obj := b.conn.Object(iwdBusName, iwdObjectPath)
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
err := obj.Call(dbusObjectManager+".GetManagedObjects", 0).Store(&objects)
if err != nil {
return nil, err
}
known := make(map[string]bool)
for _, interfaces := range objects {
if knownProps, ok := interfaces[iwdKnownNetworkInterface]; ok {
if nameVar, ok := knownProps["Name"]; ok {
if name, ok := nameVar.Value().(string); ok {
known[name] = true
}
}
}
}
return known, nil
}
func (b *IWDBackend) getAutoconnectSettings() (map[string]bool, error) {
obj := b.conn.Object(iwdBusName, iwdObjectPath)
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
err := obj.Call(dbusObjectManager+".GetManagedObjects", 0).Store(&objects)
if err != nil {
return nil, err
}
autoconnectMap := make(map[string]bool)
for _, interfaces := range objects {
if knownProps, ok := interfaces[iwdKnownNetworkInterface]; ok {
if nameVar, ok := knownProps["Name"]; ok {
if name, ok := nameVar.Value().(string); ok {
autoconnect := true
if acVar, ok := knownProps["AutoConnect"]; ok {
if ac, ok := acVar.Value().(bool); ok {
autoconnect = ac
}
}
autoconnectMap[name] = autoconnect
}
}
}
}
return autoconnectMap, nil
}
func (b *IWDBackend) GetWiFiNetworkDetails(ssid string) (*NetworkInfoResponse, error) {
b.stateMutex.RLock()
networks := b.state.WiFiNetworks
b.stateMutex.RUnlock()
var found *WiFiNetwork
for i := range networks {
if networks[i].SSID == ssid {
found = &networks[i]
break
}
}
if found == nil {
return nil, fmt.Errorf("network not found: %s", ssid)
}
return &NetworkInfoResponse{
SSID: ssid,
Bands: []WiFiNetwork{*found},
}, nil
}
func (b *IWDBackend) setConnectError(code string) {
b.stateMutex.Lock()
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = code
b.stateMutex.Unlock()
}
func (b *IWDBackend) seenInRecentScan(ssid string) bool {
b.recentScansMu.Lock()
defer b.recentScansMu.Unlock()
lastSeen, ok := b.recentScans[ssid]
return ok && time.Since(lastSeen) < 30*time.Second
}
func (b *IWDBackend) classifyAttempt(att *connectAttempt) string {
att.mu.Lock()
defer att.mu.Unlock()
if att.sawPromptRetry {
return errdefs.ErrBadCredentials
}
if !att.connectedAt.IsZero() && !att.sawIPConfig {
connDuration := time.Since(att.connectedAt)
if connDuration > 500*time.Millisecond && connDuration < 3*time.Second {
return errdefs.ErrBadCredentials
}
}
if (att.sawAuthish || !att.connectedAt.IsZero()) && !att.sawIPConfig {
if time.Since(att.start) > 12*time.Second {
return errdefs.ErrDhcpTimeout
}
}
if !att.sawAuthish && att.connectedAt.IsZero() {
if !b.seenInRecentScan(att.ssid) {
return errdefs.ErrNoSuchSSID
}
return errdefs.ErrAssocTimeout
}
return errdefs.ErrAssocTimeout
}
func (b *IWDBackend) finalizeAttempt(att *connectAttempt, code string) {
att.mu.Lock()
if att.finalized {
att.mu.Unlock()
return
}
att.finalized = true
att.mu.Unlock()
b.stateMutex.Lock()
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = code
b.stateMutex.Unlock()
b.updateState()
if b.onStateChange != nil {
b.onStateChange()
}
}
func (b *IWDBackend) startAttemptWatchdog(att *connectAttempt) {
b.sigWG.Add(1)
go func() {
defer b.sigWG.Done()
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
att.mu.Lock()
finalized := att.finalized
att.mu.Unlock()
if finalized || time.Now().After(att.deadline) {
if !finalized {
b.finalizeAttempt(att, b.classifyAttempt(att))
}
return
}
station := b.conn.Object(iwdBusName, b.stationPath)
stVar, err := station.GetProperty(iwdStationInterface + ".State")
if err != nil {
continue
}
state, _ := stVar.Value().(string)
cnVar, err := station.GetProperty(iwdStationInterface + ".ConnectedNetwork")
if err != nil {
continue
}
var connPath dbus.ObjectPath
if cnVar.Value() != nil {
connPath, _ = cnVar.Value().(dbus.ObjectPath)
}
att.mu.Lock()
if connPath == att.netPath && state == "connected" && att.connectedAt.IsZero() {
att.connectedAt = time.Now()
}
if state == "configuring" {
att.sawIPConfig = true
}
att.mu.Unlock()
case <-b.stopChan:
return
}
}
}()
}
func (b *IWDBackend) mapIwdDBusError(name string) string {
switch name {
case "net.connman.iwd.Error.AlreadyConnected":
return errdefs.ErrAlreadyConnected
case "net.connman.iwd.Error.AuthenticationFailed",
"net.connman.iwd.Error.InvalidKey",
"net.connman.iwd.Error.IncorrectPassphrase":
return errdefs.ErrBadCredentials
case "net.connman.iwd.Error.NotFound":
return errdefs.ErrNoSuchSSID
case "net.connman.iwd.Error.NotSupported":
return errdefs.ErrConnectionFailed
case "net.connman.iwd.Agent.Error.Canceled":
return errdefs.ErrUserCanceled
default:
return errdefs.ErrConnectionFailed
}
}
func (b *IWDBackend) ConnectWiFi(req ConnectionRequest) error {
if b.stationPath == "" {
b.setConnectError(errdefs.ErrWifiDisabled)
if b.onStateChange != nil {
b.onStateChange()
}
return fmt.Errorf("no WiFi device available")
}
networkPath, err := b.findNetworkPath(req.SSID)
if err != nil {
b.setConnectError(errdefs.ErrNoSuchSSID)
if b.onStateChange != nil {
b.onStateChange()
}
return fmt.Errorf("network not found: %w", err)
}
att := &connectAttempt{
ssid: req.SSID,
netPath: networkPath,
start: time.Now(),
deadline: time.Now().Add(15 * time.Second),
}
b.attemptMutex.Lock()
b.curAttempt = att
b.attemptMutex.Unlock()
b.stateMutex.Lock()
b.state.IsConnecting = true
b.state.ConnectingSSID = req.SSID
b.state.LastError = ""
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
netObj := b.conn.Object(iwdBusName, networkPath)
go func() {
call := netObj.Call(iwdNetworkInterface+".Connect", 0)
if call.Err != nil {
var code string
if dbusErr, ok := call.Err.(dbus.Error); ok {
code = b.mapIwdDBusError(dbusErr.Name)
} else if dbusErrPtr, ok := call.Err.(*dbus.Error); ok {
code = b.mapIwdDBusError(dbusErrPtr.Name)
} else {
code = errdefs.ErrConnectionFailed
}
att.mu.Lock()
if att.sawPromptRetry {
code = errdefs.ErrBadCredentials
}
att.mu.Unlock()
b.finalizeAttempt(att, code)
return
}
b.startAttemptWatchdog(att)
}()
return nil
}
func (b *IWDBackend) findNetworkPath(ssid string) (dbus.ObjectPath, error) {
obj := b.conn.Object(iwdBusName, iwdObjectPath)
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
err := obj.Call(dbusObjectManager+".GetManagedObjects", 0).Store(&objects)
if err != nil {
return "", err
}
for path, interfaces := range objects {
if netProps, ok := interfaces[iwdNetworkInterface]; ok {
if nameVar, ok := netProps["Name"]; ok {
if name, ok := nameVar.Value().(string); ok && name == ssid {
return path, nil
}
}
}
}
return "", fmt.Errorf("network not found")
}
func (b *IWDBackend) DisconnectWiFi() error {
if b.stationPath == "" {
return fmt.Errorf("no WiFi device available")
}
obj := b.conn.Object(iwdBusName, b.stationPath)
call := obj.Call(iwdStationInterface+".Disconnect", 0)
if call.Err != nil {
return fmt.Errorf("failed to disconnect: %w", call.Err)
}
b.updateState()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *IWDBackend) ForgetWiFiNetwork(ssid string) error {
b.stateMutex.RLock()
currentSSID := b.state.WiFiSSID
isConnected := b.state.WiFiConnected
b.stateMutex.RUnlock()
obj := b.conn.Object(iwdBusName, iwdObjectPath)
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
err := obj.Call(dbusObjectManager+".GetManagedObjects", 0).Store(&objects)
if err != nil {
return err
}
for path, interfaces := range objects {
if knownProps, ok := interfaces[iwdKnownNetworkInterface]; ok {
if nameVar, ok := knownProps["Name"]; ok {
if name, ok := nameVar.Value().(string); ok && name == ssid {
knownObj := b.conn.Object(iwdBusName, path)
call := knownObj.Call(iwdKnownNetworkInterface+".Forget", 0)
if call.Err != nil {
return fmt.Errorf("failed to forget network: %w", call.Err)
}
if isConnected && currentSSID == ssid {
b.stateMutex.Lock()
b.state.WiFiConnected = false
b.state.WiFiSSID = ""
b.state.WiFiSignal = 0
b.state.WiFiIP = ""
b.state.NetworkStatus = StatusDisconnected
b.stateMutex.Unlock()
}
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
}
}
}
return fmt.Errorf("network not found")
}
func (b *IWDBackend) SetWiFiAutoconnect(ssid string, autoconnect bool) error {
obj := b.conn.Object(iwdBusName, iwdObjectPath)
var objects map[dbus.ObjectPath]map[string]map[string]dbus.Variant
err := obj.Call(dbusObjectManager+".GetManagedObjects", 0).Store(&objects)
if err != nil {
return err
}
for path, interfaces := range objects {
if knownProps, ok := interfaces[iwdKnownNetworkInterface]; ok {
if nameVar, ok := knownProps["Name"]; ok {
if name, ok := nameVar.Value().(string); ok && name == ssid {
knownObj := b.conn.Object(iwdBusName, path)
call := knownObj.Call(dbusPropertiesInterface+".Set", 0, iwdKnownNetworkInterface, "AutoConnect", dbus.MakeVariant(autoconnect))
if call.Err != nil {
return fmt.Errorf("failed to set autoconnect: %w", call.Err)
}
b.updateState()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
}
}
}
return fmt.Errorf("network not found")
}

View File

@@ -0,0 +1,268 @@
package network
import (
"fmt"
"net"
"strings"
"sync"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/godbus/dbus/v5"
)
const (
networkdBusName = "org.freedesktop.network1"
networkdManagerPath = "/org/freedesktop/network1"
networkdManagerIface = "org.freedesktop.network1.Manager"
networkdLinkIface = "org.freedesktop.network1.Link"
)
type linkInfo struct {
ifindex int32
name string
path dbus.ObjectPath
opState string
}
type SystemdNetworkdBackend struct {
conn *dbus.Conn
managerPath dbus.ObjectPath
links map[string]*linkInfo
linksMutex sync.RWMutex
state *BackendState
stateMutex sync.RWMutex
onStateChange func()
stopChan chan struct{}
signals chan *dbus.Signal
sigWG sync.WaitGroup
}
func NewSystemdNetworkdBackend() (*SystemdNetworkdBackend, error) {
return &SystemdNetworkdBackend{
managerPath: networkdManagerPath,
links: make(map[string]*linkInfo),
state: &BackendState{
Backend: "networkd",
WiFiNetworks: []WiFiNetwork{},
},
stopChan: make(chan struct{}),
}, nil
}
func (b *SystemdNetworkdBackend) Initialize() error {
c, err := dbus.ConnectSystemBus()
if err != nil {
return fmt.Errorf("connect bus: %w", err)
}
b.conn = c
if err := b.enumerateLinks(); err != nil {
c.Close()
return fmt.Errorf("enumerate links: %w", err)
}
if err := b.updateState(); err != nil {
c.Close()
return fmt.Errorf("update initial state: %w", err)
}
return nil
}
func (b *SystemdNetworkdBackend) Close() {
close(b.stopChan)
b.StopMonitoring()
if b.conn != nil {
b.conn.Close()
}
}
func (b *SystemdNetworkdBackend) enumerateLinks() error {
obj := b.conn.Object(networkdBusName, b.managerPath)
var links []struct {
Ifindex int32
Name string
Path dbus.ObjectPath
}
err := obj.Call(networkdManagerIface+".ListLinks", 0).Store(&links)
if err != nil {
return fmt.Errorf("ListLinks: %w", err)
}
b.linksMutex.Lock()
defer b.linksMutex.Unlock()
for _, l := range links {
b.links[l.Name] = &linkInfo{
ifindex: l.Ifindex,
name: l.Name,
path: l.Path,
}
log.Debugf("networkd: enumerated link %s (ifindex=%d, path=%s)", l.Name, l.Ifindex, l.Path)
}
return nil
}
func (b *SystemdNetworkdBackend) updateState() error {
b.linksMutex.RLock()
defer b.linksMutex.RUnlock()
var wiredIface *linkInfo
var wifiIface *linkInfo
for name, link := range b.links {
if b.isVirtualInterface(name) {
continue
}
linkObj := b.conn.Object(networkdBusName, link.path)
opStateVar, err := linkObj.GetProperty(networkdLinkIface + ".OperationalState")
if err == nil {
if opState, ok := opStateVar.Value().(string); ok {
link.opState = opState
}
}
if strings.HasPrefix(name, "wlan") || strings.HasPrefix(name, "wlp") {
if wifiIface == nil || link.opState == "routable" || link.opState == "carrier" {
wifiIface = link
}
} else if !b.isVirtualInterface(name) {
if wiredIface == nil || link.opState == "routable" || link.opState == "carrier" {
wiredIface = link
}
}
}
var wiredConns []WiredConnection
for name, link := range b.links {
if b.isVirtualInterface(name) || strings.HasPrefix(name, "wlan") || strings.HasPrefix(name, "wlp") {
continue
}
active := link.opState == "routable" || link.opState == "carrier"
wiredConns = append(wiredConns, WiredConnection{
Path: link.path,
ID: name,
UUID: "wired:" + name,
Type: "ethernet",
IsActive: active,
})
}
b.stateMutex.Lock()
defer b.stateMutex.Unlock()
b.state.NetworkStatus = StatusDisconnected
b.state.EthernetConnected = false
b.state.EthernetIP = ""
b.state.WiFiConnected = false
b.state.WiFiIP = ""
b.state.WiredConnections = wiredConns
if wiredIface != nil {
b.state.EthernetDevice = wiredIface.name
log.Debugf("networkd: wired interface %s opState=%s", wiredIface.name, wiredIface.opState)
if wiredIface.opState == "routable" || wiredIface.opState == "carrier" {
b.state.EthernetConnected = true
b.state.NetworkStatus = StatusEthernet
if addrs := b.getAddresses(wiredIface.name); len(addrs) > 0 {
b.state.EthernetIP = addrs[0]
log.Debugf("networkd: ethernet IP %s on %s", addrs[0], wiredIface.name)
}
}
}
if wifiIface != nil {
b.state.WiFiDevice = wifiIface.name
log.Debugf("networkd: wifi interface %s opState=%s", wifiIface.name, wifiIface.opState)
if wifiIface.opState == "routable" || wifiIface.opState == "carrier" {
b.state.WiFiConnected = true
if addrs := b.getAddresses(wifiIface.name); len(addrs) > 0 {
b.state.WiFiIP = addrs[0]
log.Debugf("networkd: wifi IP %s on %s", addrs[0], wifiIface.name)
if b.state.NetworkStatus == StatusDisconnected {
b.state.NetworkStatus = StatusWiFi
}
}
}
}
return nil
}
func (b *SystemdNetworkdBackend) isVirtualInterface(name string) bool {
virtualPrefixes := []string{
"lo", "docker", "veth", "virbr", "br-", "vnet", "tun", "tap",
"vboxnet", "vmnet", "kube", "cni", "flannel", "cali",
}
for _, prefix := range virtualPrefixes {
if strings.HasPrefix(name, prefix) {
return true
}
}
return false
}
func (b *SystemdNetworkdBackend) getAddresses(ifname string) []string {
iface, err := net.InterfaceByName(ifname)
if err != nil {
return nil
}
addrs, err := iface.Addrs()
if err != nil {
return nil
}
var result []string
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok {
if ipv4 := ipnet.IP.To4(); ipv4 != nil {
result = append(result, ipv4.String())
}
}
}
return result
}
func (b *SystemdNetworkdBackend) GetCurrentState() (*BackendState, error) {
b.stateMutex.RLock()
defer b.stateMutex.RUnlock()
s := *b.state
return &s, nil
}
func (b *SystemdNetworkdBackend) GetPromptBroker() PromptBroker {
return nil
}
func (b *SystemdNetworkdBackend) SetPromptBroker(broker PromptBroker) error {
return nil
}
func (b *SystemdNetworkdBackend) SubmitCredentials(token string, secrets map[string]string, save bool) error {
return fmt.Errorf("credentials not needed by networkd backend")
}
func (b *SystemdNetworkdBackend) CancelCredentials(token string) error {
return fmt.Errorf("credentials not needed by networkd backend")
}
func (b *SystemdNetworkdBackend) EnsureDhcpUp(ifname string) error {
b.linksMutex.RLock()
link, exists := b.links[ifname]
b.linksMutex.RUnlock()
if !exists {
return fmt.Errorf("interface %s not found", ifname)
}
linkObj := b.conn.Object(networkdBusName, link.path)
return linkObj.Call(networkdLinkIface+".Reconfigure", 0).Err
}

View File

@@ -0,0 +1,110 @@
package network
import (
"fmt"
"net"
"strings"
)
func (b *SystemdNetworkdBackend) GetWiredConnections() ([]WiredConnection, error) {
b.linksMutex.RLock()
defer b.linksMutex.RUnlock()
var conns []WiredConnection
for name, link := range b.links {
if b.isVirtualInterface(name) || strings.HasPrefix(name, "wlan") || strings.HasPrefix(name, "wlp") {
continue
}
active := link.opState == "routable" || link.opState == "carrier"
conns = append(conns, WiredConnection{
Path: link.path,
ID: name,
UUID: "wired:" + name,
Type: "ethernet",
IsActive: active,
})
}
return conns, nil
}
func (b *SystemdNetworkdBackend) GetWiredNetworkDetails(id string) (*WiredNetworkInfoResponse, error) {
ifname := strings.TrimPrefix(id, "wired:")
b.linksMutex.RLock()
_, exists := b.links[ifname]
b.linksMutex.RUnlock()
if !exists {
return nil, fmt.Errorf("interface %s not found", ifname)
}
iface, err := net.InterfaceByName(ifname)
if err != nil {
return nil, fmt.Errorf("get interface: %w", err)
}
addrs, _ := iface.Addrs()
var ipv4s, ipv6s []string
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok {
if ipv4 := ipnet.IP.To4(); ipv4 != nil {
ipv4s = append(ipv4s, ipnet.String())
} else if ipv6 := ipnet.IP.To16(); ipv6 != nil {
ipv6s = append(ipv6s, ipnet.String())
}
}
}
return &WiredNetworkInfoResponse{
UUID: id,
IFace: ifname,
HwAddr: iface.HardwareAddr.String(),
IPv4: WiredIPConfig{
IPs: ipv4s,
},
IPv6: WiredIPConfig{
IPs: ipv6s,
},
}, nil
}
func (b *SystemdNetworkdBackend) ConnectEthernet() error {
b.linksMutex.RLock()
var primaryWired *linkInfo
for name, l := range b.links {
if strings.HasPrefix(name, "lo") || strings.HasPrefix(name, "wlan") || strings.HasPrefix(name, "wlp") {
continue
}
primaryWired = l
break
}
b.linksMutex.RUnlock()
if primaryWired == nil {
return fmt.Errorf("no wired interface found")
}
linkObj := b.conn.Object(networkdBusName, primaryWired.path)
return linkObj.Call(networkdLinkIface+".Reconfigure", 0).Err
}
func (b *SystemdNetworkdBackend) DisconnectEthernet() error {
return fmt.Errorf("not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) ActivateWiredConnection(id string) error {
ifname := strings.TrimPrefix(id, "wired:")
b.linksMutex.RLock()
link, exists := b.links[ifname]
b.linksMutex.RUnlock()
if !exists {
return fmt.Errorf("interface %s not found", ifname)
}
linkObj := b.conn.Object(networkdBusName, link.path)
return linkObj.Call(networkdLinkIface+".Reconfigure", 0).Err
}

View File

@@ -0,0 +1,68 @@
package network
import (
"fmt"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/godbus/dbus/v5"
)
func (b *SystemdNetworkdBackend) StartMonitoring(onStateChange func()) error {
b.onStateChange = onStateChange
b.signals = make(chan *dbus.Signal, 64)
b.conn.Signal(b.signals)
matchRules := []string{
"type='signal',interface='org.freedesktop.DBus.Properties',member='PropertiesChanged',path_namespace='/org/freedesktop/network1'",
"type='signal',interface='org.freedesktop.network1.Manager'",
}
for _, rule := range matchRules {
if err := b.conn.BusObject().Call("org.freedesktop.DBus.AddMatch", 0, rule).Err; err != nil {
return fmt.Errorf("add match %q: %w", rule, err)
}
}
b.sigWG.Add(1)
go b.signalLoop()
return nil
}
func (b *SystemdNetworkdBackend) StopMonitoring() {
b.sigWG.Wait()
}
func (b *SystemdNetworkdBackend) signalLoop() {
defer b.sigWG.Done()
for {
select {
case <-b.stopChan:
return
case sig := <-b.signals:
if sig == nil {
continue
}
if sig.Name == "org.freedesktop.DBus.Properties.PropertiesChanged" {
if len(sig.Body) < 2 {
continue
}
iface, ok := sig.Body[0].(string)
if !ok || iface != networkdLinkIface {
continue
}
b.enumerateLinks()
if err := b.updateState(); err != nil {
log.Warnf("networkd state update failed: %v", err)
}
if b.onStateChange != nil {
b.onStateChange()
}
}
}
}
}

View File

@@ -0,0 +1,125 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSystemdNetworkdBackend_New(t *testing.T) {
backend, err := NewSystemdNetworkdBackend()
assert.NoError(t, err)
assert.NotNil(t, backend)
assert.Equal(t, "networkd", backend.state.Backend)
assert.NotNil(t, backend.links)
assert.NotNil(t, backend.stopChan)
}
func TestSystemdNetworkdBackend_GetCurrentState(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
backend.state.NetworkStatus = StatusEthernet
backend.state.EthernetConnected = true
backend.state.EthernetIP = "192.168.1.100"
state, err := backend.GetCurrentState()
assert.NoError(t, err)
assert.NotNil(t, state)
assert.Equal(t, StatusEthernet, state.NetworkStatus)
assert.True(t, state.EthernetConnected)
assert.Equal(t, "192.168.1.100", state.EthernetIP)
}
func TestSystemdNetworkdBackend_WiFiNotSupported(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
err := backend.ScanWiFi()
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
req := ConnectionRequest{SSID: "test"}
err = backend.ConnectWiFi(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
err = backend.DisconnectWiFi()
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
err = backend.ForgetWiFiNetwork("test")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
_, err = backend.GetWiFiNetworkDetails("test")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
}
func TestSystemdNetworkdBackend_VPNNotSupported(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
profiles, err := backend.ListVPNProfiles()
assert.NoError(t, err)
assert.Empty(t, profiles)
active, err := backend.ListActiveVPN()
assert.NoError(t, err)
assert.Empty(t, active)
err = backend.ConnectVPN("test", false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
err = backend.DisconnectVPN("test")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
err = backend.DisconnectAllVPN()
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
err = backend.ClearVPNCredentials("test")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
}
func TestSystemdNetworkdBackend_PromptBroker(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
broker := backend.GetPromptBroker()
assert.Nil(t, broker)
err := backend.SetPromptBroker(nil)
assert.NoError(t, err)
err = backend.SubmitCredentials("token", nil, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not needed")
err = backend.CancelCredentials("token")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not needed")
}
func TestSystemdNetworkdBackend_GetWiFiEnabled(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
enabled, err := backend.GetWiFiEnabled()
assert.NoError(t, err)
assert.True(t, enabled)
}
func TestSystemdNetworkdBackend_SetWiFiEnabled(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
err := backend.SetWiFiEnabled(false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
}
func TestSystemdNetworkdBackend_DisconnectEthernet(t *testing.T) {
backend, _ := NewSystemdNetworkdBackend()
err := backend.DisconnectEthernet()
assert.Error(t, err)
assert.Contains(t, err.Error(), "not supported")
}

View File

@@ -0,0 +1,59 @@
package network
import "fmt"
func (b *SystemdNetworkdBackend) GetWiFiEnabled() (bool, error) {
return true, nil
}
func (b *SystemdNetworkdBackend) SetWiFiEnabled(enabled bool) error {
return fmt.Errorf("WiFi control not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) ScanWiFi() error {
return fmt.Errorf("WiFi scan not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) GetWiFiNetworkDetails(ssid string) (*NetworkInfoResponse, error) {
return nil, fmt.Errorf("WiFi details not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) ConnectWiFi(req ConnectionRequest) error {
return fmt.Errorf("WiFi connect not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) DisconnectWiFi() error {
return fmt.Errorf("WiFi disconnect not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) ForgetWiFiNetwork(ssid string) error {
return fmt.Errorf("WiFi forget not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) ListVPNProfiles() ([]VPNProfile, error) {
return []VPNProfile{}, nil
}
func (b *SystemdNetworkdBackend) ListActiveVPN() ([]VPNActive, error) {
return []VPNActive{}, nil
}
func (b *SystemdNetworkdBackend) ConnectVPN(uuidOrName string, singleActive bool) error {
return fmt.Errorf("VPN not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) DisconnectVPN(uuidOrName string) error {
return fmt.Errorf("VPN not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) DisconnectAllVPN() error {
return fmt.Errorf("VPN not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) ClearVPNCredentials(uuidOrName string) error {
return fmt.Errorf("VPN not supported by networkd backend")
}
func (b *SystemdNetworkdBackend) SetWiFiAutoconnect(ssid string, autoconnect bool) error {
return fmt.Errorf("WiFi autoconnect not supported by networkd backend")
}

View File

@@ -0,0 +1,307 @@
package network
import (
"fmt"
"sync"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/Wifx/gonetworkmanager/v2"
"github.com/godbus/dbus/v5"
)
const (
dbusNMPath = "/org/freedesktop/NetworkManager"
dbusNMInterface = "org.freedesktop.NetworkManager"
dbusNMDeviceInterface = "org.freedesktop.NetworkManager.Device"
dbusNMWirelessInterface = "org.freedesktop.NetworkManager.Device.Wireless"
dbusNMAccessPointInterface = "org.freedesktop.NetworkManager.AccessPoint"
dbusPropsInterface = "org.freedesktop.DBus.Properties"
NmDeviceStateReasonWrongPassword = 8
NmDeviceStateReasonSupplicantTimeout = 24
NmDeviceStateReasonSupplicantFailed = 25
NmDeviceStateReasonSecretsRequired = 7
NmDeviceStateReasonNoSecrets = 6
NmDeviceStateReasonNoSsid = 10
NmDeviceStateReasonDhcpClientFailed = 14
NmDeviceStateReasonIpConfigUnavailable = 18
NmDeviceStateReasonSupplicantDisconnect = 23
NmDeviceStateReasonCarrier = 40
NmDeviceStateReasonNewActivation = 60
)
type NetworkManagerBackend struct {
nmConn interface{}
ethernetDevice interface{}
wifiDevice interface{}
settings interface{}
wifiDev interface{}
dbusConn *dbus.Conn
signals chan *dbus.Signal
sigWG sync.WaitGroup
stopChan chan struct{}
secretAgent *SecretAgent
promptBroker PromptBroker
state *BackendState
stateMutex sync.RWMutex
lastFailedSSID string
lastFailedTime int64
failedMutex sync.RWMutex
onStateChange func()
}
func NewNetworkManagerBackend(nmConn ...gonetworkmanager.NetworkManager) (*NetworkManagerBackend, error) {
var nm gonetworkmanager.NetworkManager
var err error
if len(nmConn) > 0 && nmConn[0] != nil {
// Use injected connection (for testing)
nm = nmConn[0]
} else {
// Create real connection
nm, err = gonetworkmanager.NewNetworkManager()
if err != nil {
return nil, fmt.Errorf("failed to connect to NetworkManager: %w", err)
}
}
backend := &NetworkManagerBackend{
nmConn: nm,
stopChan: make(chan struct{}),
state: &BackendState{
Backend: "networkmanager",
},
}
return backend, nil
}
func (b *NetworkManagerBackend) Initialize() error {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
if s, err := gonetworkmanager.NewSettings(); err == nil {
b.settings = s
}
devices, err := nm.GetDevices()
if err != nil {
return fmt.Errorf("failed to get devices: %w", err)
}
for _, dev := range devices {
devType, err := dev.GetPropertyDeviceType()
if err != nil {
continue
}
switch devType {
case gonetworkmanager.NmDeviceTypeEthernet:
if managed, _ := dev.GetPropertyManaged(); !managed {
continue
}
b.ethernetDevice = dev
if err := b.updateEthernetState(); err != nil {
continue
}
_, err := b.listEthernetConnections()
if err != nil {
return fmt.Errorf("failed to get wired configurations: %w", err)
}
case gonetworkmanager.NmDeviceTypeWifi:
b.wifiDevice = dev
if w, err := gonetworkmanager.NewDeviceWireless(dev.GetPath()); err == nil {
b.wifiDev = w
}
wifiEnabled, err := nm.GetPropertyWirelessEnabled()
if err == nil {
b.stateMutex.Lock()
b.state.WiFiEnabled = wifiEnabled
b.stateMutex.Unlock()
}
if err := b.updateWiFiState(); err != nil {
continue
}
if wifiEnabled {
if _, err := b.updateWiFiNetworks(); err != nil {
log.Warnf("Failed to get initial networks: %v", err)
}
}
}
}
if err := b.updatePrimaryConnection(); err != nil {
return err
}
if _, err := b.ListVPNProfiles(); err != nil {
log.Warnf("Failed to get initial VPN profiles: %v", err)
}
if _, err := b.ListActiveVPN(); err != nil {
log.Warnf("Failed to get initial active VPNs: %v", err)
}
return nil
}
func (b *NetworkManagerBackend) Close() {
close(b.stopChan)
b.StopMonitoring()
if b.secretAgent != nil {
b.secretAgent.Close()
}
}
func (b *NetworkManagerBackend) GetCurrentState() (*BackendState, error) {
b.stateMutex.RLock()
defer b.stateMutex.RUnlock()
state := *b.state
state.WiFiNetworks = append([]WiFiNetwork(nil), b.state.WiFiNetworks...)
state.WiredConnections = append([]WiredConnection(nil), b.state.WiredConnections...)
state.VPNProfiles = append([]VPNProfile(nil), b.state.VPNProfiles...)
state.VPNActive = append([]VPNActive(nil), b.state.VPNActive...)
return &state, nil
}
func (b *NetworkManagerBackend) StartMonitoring(onStateChange func()) error {
b.onStateChange = onStateChange
if err := b.startSecretAgent(); err != nil {
return fmt.Errorf("failed to start secret agent: %w", err)
}
if err := b.startSignalPump(); err != nil {
return err
}
return nil
}
func (b *NetworkManagerBackend) StopMonitoring() {
b.stopSignalPump()
}
func (b *NetworkManagerBackend) GetPromptBroker() PromptBroker {
return b.promptBroker
}
func (b *NetworkManagerBackend) SetPromptBroker(broker PromptBroker) error {
if broker == nil {
return fmt.Errorf("broker cannot be nil")
}
hadAgent := b.secretAgent != nil
b.promptBroker = broker
if b.secretAgent != nil {
b.secretAgent.Close()
b.secretAgent = nil
}
if hadAgent {
return b.startSecretAgent()
}
return nil
}
func (b *NetworkManagerBackend) SubmitCredentials(token string, secrets map[string]string, save bool) error {
if b.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
return b.promptBroker.Resolve(token, PromptReply{
Secrets: secrets,
Save: save,
Cancel: false,
})
}
func (b *NetworkManagerBackend) CancelCredentials(token string) error {
if b.promptBroker == nil {
return fmt.Errorf("prompt broker not initialized")
}
return b.promptBroker.Resolve(token, PromptReply{
Cancel: true,
})
}
func (b *NetworkManagerBackend) ensureWiFiDevice() error {
if b.wifiDev != nil {
return nil
}
if b.wifiDevice == nil {
return fmt.Errorf("no WiFi device available")
}
dev := b.wifiDevice.(gonetworkmanager.Device)
wifiDev, err := gonetworkmanager.NewDeviceWireless(dev.GetPath())
if err != nil {
return fmt.Errorf("failed to get wireless device: %w", err)
}
b.wifiDev = wifiDev
return nil
}
func (b *NetworkManagerBackend) startSecretAgent() error {
if b.promptBroker == nil {
return fmt.Errorf("prompt broker not set")
}
agent, err := NewSecretAgent(b.promptBroker, nil, b)
if err != nil {
return err
}
b.secretAgent = agent
return nil
}
func (b *NetworkManagerBackend) getActiveConnections() (map[string]bool, error) {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeUUIDs := make(map[string]bool)
activeConns, err := nm.GetPropertyActiveConnections()
if err != nil {
return activeUUIDs, fmt.Errorf("failed to get active connections: %w", err)
}
for _, activeConn := range activeConns {
connType, err := activeConn.GetPropertyType()
if err != nil {
continue
}
if connType != "802-3-ethernet" {
continue
}
state, err := activeConn.GetPropertyState()
if err != nil {
continue
}
if state < 1 || state > 2 {
continue
}
uuid, err := activeConn.GetPropertyUUID()
if err != nil {
continue
}
activeUUIDs[uuid] = true
}
return activeUUIDs, nil
}

View File

@@ -0,0 +1,317 @@
package network
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/Wifx/gonetworkmanager/v2"
)
func (b *NetworkManagerBackend) GetWiredConnections() ([]WiredConnection, error) {
return b.listEthernetConnections()
}
func (b *NetworkManagerBackend) GetWiredNetworkDetails(uuid string) (*WiredNetworkInfoResponse, error) {
if b.ethernetDevice == nil {
return nil, fmt.Errorf("no ethernet device available")
}
dev := b.ethernetDevice.(gonetworkmanager.Device)
iface, _ := dev.GetPropertyInterface()
driver, _ := dev.GetPropertyDriver()
hwAddr := "Not available"
var speed uint32 = 0
wiredDevice, err := gonetworkmanager.NewDeviceWired(dev.GetPath())
if err == nil {
hwAddr, _ = wiredDevice.GetPropertyHwAddress()
speed, _ = wiredDevice.GetPropertySpeed()
}
var ipv4Config WiredIPConfig
var ipv6Config WiredIPConfig
activeConn, err := dev.GetPropertyActiveConnection()
if err == nil && activeConn != nil {
ip4Config, err := activeConn.GetPropertyIP4Config()
if err == nil && ip4Config != nil {
var ips []string
addresses, err := ip4Config.GetPropertyAddressData()
if err == nil && len(addresses) > 0 {
for _, addr := range addresses {
ips = append(ips, fmt.Sprintf("%s/%s", addr.Address, strconv.Itoa(int(addr.Prefix))))
}
}
gateway, _ := ip4Config.GetPropertyGateway()
dnsAddrs := ""
dns, err := ip4Config.GetPropertyNameserverData()
if err == nil && len(dns) > 0 {
for _, d := range dns {
if len(dnsAddrs) > 0 {
dnsAddrs = strings.Join([]string{dnsAddrs, d.Address}, "; ")
} else {
dnsAddrs = d.Address
}
}
}
ipv4Config = WiredIPConfig{
IPs: ips,
Gateway: gateway,
DNS: dnsAddrs,
}
}
ip6Config, err := activeConn.GetPropertyIP6Config()
if err == nil && ip6Config != nil {
var ips []string
addresses, err := ip6Config.GetPropertyAddressData()
if err == nil && len(addresses) > 0 {
for _, addr := range addresses {
ips = append(ips, fmt.Sprintf("%s/%s", addr.Address, strconv.Itoa(int(addr.Prefix))))
}
}
gateway, _ := ip6Config.GetPropertyGateway()
dnsAddrs := ""
dns, err := ip6Config.GetPropertyNameservers()
if err == nil && len(dns) > 0 {
for _, d := range dns {
if len(d) == 16 {
ip := net.IP(d)
if len(dnsAddrs) > 0 {
dnsAddrs = strings.Join([]string{dnsAddrs, ip.String()}, "; ")
} else {
dnsAddrs = ip.String()
}
}
}
}
ipv6Config = WiredIPConfig{
IPs: ips,
Gateway: gateway,
DNS: dnsAddrs,
}
}
}
return &WiredNetworkInfoResponse{
UUID: uuid,
IFace: iface,
Driver: driver,
HwAddr: hwAddr,
Speed: strconv.Itoa(int(speed)),
IPv4: ipv4Config,
IPv6: ipv6Config,
}, nil
}
func (b *NetworkManagerBackend) ConnectEthernet() error {
if b.ethernetDevice == nil {
return fmt.Errorf("no ethernet device available")
}
nm := b.nmConn.(gonetworkmanager.NetworkManager)
dev := b.ethernetDevice.(gonetworkmanager.Device)
settingsMgr, err := gonetworkmanager.NewSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
}
connections, err := settingsMgr.ListConnections()
if err != nil {
return fmt.Errorf("failed to get connections: %w", err)
}
for _, conn := range connections {
connSettings, err := conn.GetSettings()
if err != nil {
continue
}
if connMeta, ok := connSettings["connection"]; ok {
if connType, ok := connMeta["type"].(string); ok && connType == "802-3-ethernet" {
_, err := nm.ActivateConnection(conn, dev, nil)
if err != nil {
return fmt.Errorf("failed to activate ethernet: %w", err)
}
b.updateEthernetState()
b.listEthernetConnections()
b.updatePrimaryConnection()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
}
}
settings := make(map[string]map[string]interface{})
settings["connection"] = map[string]interface{}{
"id": "Wired connection",
"type": "802-3-ethernet",
}
_, err = nm.AddAndActivateConnection(settings, dev)
if err != nil {
return fmt.Errorf("failed to create and activate ethernet: %w", err)
}
b.updateEthernetState()
b.listEthernetConnections()
b.updatePrimaryConnection()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *NetworkManagerBackend) DisconnectEthernet() error {
if b.ethernetDevice == nil {
return fmt.Errorf("no ethernet device available")
}
dev := b.ethernetDevice.(gonetworkmanager.Device)
err := dev.Disconnect()
if err != nil {
return fmt.Errorf("failed to disconnect: %w", err)
}
b.updateEthernetState()
b.listEthernetConnections()
b.updatePrimaryConnection()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *NetworkManagerBackend) ActivateWiredConnection(uuid string) error {
if b.ethernetDevice == nil {
return fmt.Errorf("no ethernet device available")
}
nm := b.nmConn.(gonetworkmanager.NetworkManager)
dev := b.ethernetDevice.(gonetworkmanager.Device)
settingsMgr, err := gonetworkmanager.NewSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
}
connections, err := settingsMgr.ListConnections()
if err != nil {
return fmt.Errorf("failed to get connections: %w", err)
}
var targetConnection gonetworkmanager.Connection
for _, conn := range connections {
settings, err := conn.GetSettings()
if err != nil {
continue
}
if connectionSettings, ok := settings["connection"]; ok {
if connUUID, ok := connectionSettings["uuid"].(string); ok && connUUID == uuid {
targetConnection = conn
break
}
}
}
if targetConnection == nil {
return fmt.Errorf("connection with UUID %s not found", uuid)
}
_, err = nm.ActivateConnection(targetConnection, dev, nil)
if err != nil {
return fmt.Errorf("error activation connection: %w", err)
}
b.updateEthernetState()
b.listEthernetConnections()
b.updatePrimaryConnection()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *NetworkManagerBackend) listEthernetConnections() ([]WiredConnection, error) {
if b.ethernetDevice == nil {
return nil, fmt.Errorf("no ethernet device available")
}
s := b.settings
if s == nil {
s, err := gonetworkmanager.NewSettings()
if err != nil {
return nil, fmt.Errorf("failed to get settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return nil, fmt.Errorf("failed to get connections: %w", err)
}
wiredConfigs := make([]WiredConnection, 0)
activeUUIDs, err := b.getActiveConnections()
if err != nil {
return nil, fmt.Errorf("failed to get active wired connections: %w", err)
}
currentUuid := ""
for _, connection := range connections {
path := connection.GetPath()
settings, err := connection.GetSettings()
if err != nil {
log.Errorf("unable to get settings for %s: %v", path, err)
continue
}
connectionSettings := settings["connection"]
connType, _ := connectionSettings["type"].(string)
connID, _ := connectionSettings["id"].(string)
connUUID, _ := connectionSettings["uuid"].(string)
if connType == "802-3-ethernet" {
wiredConfigs = append(wiredConfigs, WiredConnection{
Path: path,
ID: connID,
UUID: connUUID,
Type: connType,
IsActive: activeUUIDs[connUUID],
})
if activeUUIDs[connUUID] {
currentUuid = connUUID
}
}
}
b.stateMutex.Lock()
b.state.EthernetConnectionUuid = currentUuid
b.state.WiredConnections = wiredConfigs
b.stateMutex.Unlock()
return wiredConfigs, nil
}

View File

@@ -0,0 +1,94 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNetworkManagerBackend_GetWiredConnections_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.ethernetDevice = nil
_, err = backend.GetWiredConnections()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
func TestNetworkManagerBackend_GetWiredNetworkDetails_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.ethernetDevice = nil
_, err = backend.GetWiredNetworkDetails("test-uuid")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
func TestNetworkManagerBackend_ConnectEthernet_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.ethernetDevice = nil
err = backend.ConnectEthernet()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
func TestNetworkManagerBackend_DisconnectEthernet_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.ethernetDevice = nil
err = backend.DisconnectEthernet()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
func TestNetworkManagerBackend_ActivateWiredConnection_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.ethernetDevice = nil
err = backend.ActivateWiredConnection("test-uuid")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
func TestNetworkManagerBackend_ActivateWiredConnection_NotFound(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
if backend.ethernetDevice == nil {
t.Skip("No ethernet device available")
}
err = backend.ActivateWiredConnection("non-existent-uuid-12345")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestNetworkManagerBackend_ListEthernetConnections_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.ethernetDevice = nil
_, err = backend.listEthernetConnections()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}

View File

@@ -0,0 +1,321 @@
package network
import (
"github.com/Wifx/gonetworkmanager/v2"
"github.com/godbus/dbus/v5"
)
func (b *NetworkManagerBackend) startSignalPump() error {
conn, err := dbus.ConnectSystemBus()
if err != nil {
return err
}
b.dbusConn = conn
signals := make(chan *dbus.Signal, 256)
b.signals = signals
conn.Signal(signals)
if err := conn.AddMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusNMPath)),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
); err != nil {
conn.RemoveSignal(signals)
conn.Close()
return err
}
if err := conn.AddMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath("/org/freedesktop/NetworkManager/Settings")),
dbus.WithMatchInterface("org.freedesktop.NetworkManager.Settings"),
dbus.WithMatchMember("NewConnection"),
); err != nil {
conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusNMPath)),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
conn.RemoveSignal(signals)
conn.Close()
return err
}
if err := conn.AddMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath("/org/freedesktop/NetworkManager/Settings")),
dbus.WithMatchInterface("org.freedesktop.NetworkManager.Settings"),
dbus.WithMatchMember("ConnectionRemoved"),
); err != nil {
conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusNMPath)),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath("/org/freedesktop/NetworkManager/Settings")),
dbus.WithMatchInterface("org.freedesktop.NetworkManager.Settings"),
dbus.WithMatchMember("NewConnection"),
)
conn.RemoveSignal(signals)
conn.Close()
return err
}
if b.wifiDevice != nil {
dev := b.wifiDevice.(gonetworkmanager.Device)
if err := conn.AddMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dev.GetPath())),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
); err != nil {
conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusNMPath)),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
conn.RemoveSignal(signals)
conn.Close()
return err
}
}
if b.ethernetDevice != nil {
dev := b.ethernetDevice.(gonetworkmanager.Device)
if err := conn.AddMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dev.GetPath())),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
); err != nil {
conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusNMPath)),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
if b.wifiDevice != nil {
dev := b.wifiDevice.(gonetworkmanager.Device)
conn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dev.GetPath())),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
}
conn.RemoveSignal(signals)
conn.Close()
return err
}
}
b.sigWG.Add(1)
go func() {
defer b.sigWG.Done()
for {
select {
case <-b.stopChan:
return
case sig, ok := <-signals:
if !ok {
return
}
if sig == nil {
continue
}
b.handleDBusSignal(sig)
}
}
}()
return nil
}
func (b *NetworkManagerBackend) stopSignalPump() {
if b.dbusConn == nil {
return
}
b.dbusConn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dbusNMPath)),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
if b.wifiDevice != nil {
dev := b.wifiDevice.(gonetworkmanager.Device)
b.dbusConn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dev.GetPath())),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
}
if b.ethernetDevice != nil {
dev := b.ethernetDevice.(gonetworkmanager.Device)
b.dbusConn.RemoveMatchSignal(
dbus.WithMatchObjectPath(dbus.ObjectPath(dev.GetPath())),
dbus.WithMatchInterface(dbusPropsInterface),
dbus.WithMatchMember("PropertiesChanged"),
)
}
if b.signals != nil {
b.dbusConn.RemoveSignal(b.signals)
close(b.signals)
}
b.sigWG.Wait()
b.dbusConn.Close()
}
func (b *NetworkManagerBackend) handleDBusSignal(sig *dbus.Signal) {
if sig.Name == "org.freedesktop.NetworkManager.Settings.NewConnection" ||
sig.Name == "org.freedesktop.NetworkManager.Settings.ConnectionRemoved" {
b.ListVPNProfiles()
if b.onStateChange != nil {
b.onStateChange()
}
return
}
if len(sig.Body) < 2 {
return
}
iface, ok := sig.Body[0].(string)
if !ok {
return
}
changes, ok := sig.Body[1].(map[string]dbus.Variant)
if !ok {
return
}
switch iface {
case dbusNMInterface:
b.handleNetworkManagerChange(changes)
case dbusNMDeviceInterface:
b.handleDeviceChange(changes)
case dbusNMWirelessInterface:
b.handleWiFiChange(changes)
case dbusNMAccessPointInterface:
b.handleAccessPointChange(changes)
}
}
func (b *NetworkManagerBackend) handleNetworkManagerChange(changes map[string]dbus.Variant) {
var needsUpdate bool
for key := range changes {
switch key {
case "PrimaryConnection", "State", "ActiveConnections":
needsUpdate = true
case "WirelessEnabled":
nm := b.nmConn.(gonetworkmanager.NetworkManager)
if enabled, err := nm.GetPropertyWirelessEnabled(); err == nil {
b.stateMutex.Lock()
b.state.WiFiEnabled = enabled
b.stateMutex.Unlock()
needsUpdate = true
}
default:
continue
}
}
if needsUpdate {
b.updatePrimaryConnection()
if _, exists := changes["State"]; exists {
b.updateEthernetState()
b.updateWiFiState()
}
if _, exists := changes["ActiveConnections"]; exists {
b.updateVPNConnectionState()
b.ListActiveVPN()
}
if b.onStateChange != nil {
b.onStateChange()
}
}
}
func (b *NetworkManagerBackend) handleDeviceChange(changes map[string]dbus.Variant) {
var needsUpdate bool
var stateChanged bool
for key := range changes {
switch key {
case "State":
stateChanged = true
needsUpdate = true
case "Ip4Config":
needsUpdate = true
default:
continue
}
}
if needsUpdate {
b.updateEthernetState()
b.updateWiFiState()
if stateChanged {
b.updatePrimaryConnection()
}
if b.onStateChange != nil {
b.onStateChange()
}
}
}
func (b *NetworkManagerBackend) handleWiFiChange(changes map[string]dbus.Variant) {
var needsStateUpdate bool
var needsNetworkUpdate bool
for key := range changes {
switch key {
case "ActiveAccessPoint":
needsStateUpdate = true
needsNetworkUpdate = true
case "AccessPoints":
needsNetworkUpdate = true
default:
continue
}
}
if needsStateUpdate {
b.updateWiFiState()
}
if needsNetworkUpdate {
b.updateWiFiNetworks()
}
if needsStateUpdate || needsNetworkUpdate {
if b.onStateChange != nil {
b.onStateChange()
}
}
}
func (b *NetworkManagerBackend) handleAccessPointChange(changes map[string]dbus.Variant) {
_, hasStrength := changes["Strength"]
if !hasStrength {
return
}
b.stateMutex.RLock()
oldSignal := b.state.WiFiSignal
b.stateMutex.RUnlock()
b.updateWiFiState()
b.stateMutex.RLock()
newSignal := b.state.WiFiSignal
b.stateMutex.RUnlock()
if signalChangeSignificant(oldSignal, newSignal) {
if b.onStateChange != nil {
b.onStateChange()
}
}
}

View File

@@ -0,0 +1,240 @@
package network
import (
"testing"
"github.com/godbus/dbus/v5"
"github.com/stretchr/testify/assert"
)
func TestNetworkManagerBackend_HandleDBusSignal_NewConnection(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
sig := &dbus.Signal{
Name: "org.freedesktop.NetworkManager.Settings.NewConnection",
Body: []interface{}{"/org/freedesktop/NetworkManager/Settings/1"},
}
assert.NotPanics(t, func() {
backend.handleDBusSignal(sig)
})
}
func TestNetworkManagerBackend_HandleDBusSignal_ConnectionRemoved(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
sig := &dbus.Signal{
Name: "org.freedesktop.NetworkManager.Settings.ConnectionRemoved",
Body: []interface{}{"/org/freedesktop/NetworkManager/Settings/1"},
}
assert.NotPanics(t, func() {
backend.handleDBusSignal(sig)
})
}
func TestNetworkManagerBackend_HandleDBusSignal_InvalidBody(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{"only-one-element"},
}
assert.NotPanics(t, func() {
backend.handleDBusSignal(sig)
})
}
func TestNetworkManagerBackend_HandleDBusSignal_InvalidInterface(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{123, map[string]dbus.Variant{}},
}
assert.NotPanics(t, func() {
backend.handleDBusSignal(sig)
})
}
func TestNetworkManagerBackend_HandleDBusSignal_InvalidChanges(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
sig := &dbus.Signal{
Name: "org.freedesktop.DBus.Properties.PropertiesChanged",
Body: []interface{}{dbusNMInterface, "not-a-map"},
}
assert.NotPanics(t, func() {
backend.handleDBusSignal(sig)
})
}
func TestNetworkManagerBackend_HandleNetworkManagerChange(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"PrimaryConnection": dbus.MakeVariant("/"),
"State": dbus.MakeVariant(uint32(70)),
}
assert.NotPanics(t, func() {
backend.handleNetworkManagerChange(changes)
})
}
func TestNetworkManagerBackend_HandleNetworkManagerChange_WirelessEnabled(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"WirelessEnabled": dbus.MakeVariant(true),
}
assert.NotPanics(t, func() {
backend.handleNetworkManagerChange(changes)
})
}
func TestNetworkManagerBackend_HandleNetworkManagerChange_ActiveConnections(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"ActiveConnections": dbus.MakeVariant([]interface{}{}),
}
assert.NotPanics(t, func() {
backend.handleNetworkManagerChange(changes)
})
}
func TestNetworkManagerBackend_HandleDeviceChange(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"State": dbus.MakeVariant(uint32(100)),
}
assert.NotPanics(t, func() {
backend.handleDeviceChange(changes)
})
}
func TestNetworkManagerBackend_HandleDeviceChange_Ip4Config(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"Ip4Config": dbus.MakeVariant("/"),
}
assert.NotPanics(t, func() {
backend.handleDeviceChange(changes)
})
}
func TestNetworkManagerBackend_HandleWiFiChange_ActiveAccessPoint(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"ActiveAccessPoint": dbus.MakeVariant("/"),
}
assert.NotPanics(t, func() {
backend.handleWiFiChange(changes)
})
}
func TestNetworkManagerBackend_HandleWiFiChange_AccessPoints(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"AccessPoints": dbus.MakeVariant([]interface{}{}),
}
assert.NotPanics(t, func() {
backend.handleWiFiChange(changes)
})
}
func TestNetworkManagerBackend_HandleAccessPointChange_NoStrength(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
changes := map[string]dbus.Variant{
"SomeOtherProperty": dbus.MakeVariant("value"),
}
assert.NotPanics(t, func() {
backend.handleAccessPointChange(changes)
})
}
func TestNetworkManagerBackend_HandleAccessPointChange_WithStrength(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.stateMutex.Lock()
backend.state.WiFiSignal = 50
backend.stateMutex.Unlock()
changes := map[string]dbus.Variant{
"Strength": dbus.MakeVariant(uint8(80)),
}
assert.NotPanics(t, func() {
backend.handleAccessPointChange(changes)
})
}
func TestNetworkManagerBackend_StopSignalPump_NoConnection(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.dbusConn = nil
assert.NotPanics(t, func() {
backend.stopSignalPump()
})
}

View File

@@ -0,0 +1,271 @@
package network
import (
"time"
"github.com/AvengeMedia/danklinux/internal/errdefs"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/Wifx/gonetworkmanager/v2"
)
func (b *NetworkManagerBackend) updatePrimaryConnection() error {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeConns, err := nm.GetPropertyActiveConnections()
if err != nil {
return err
}
hasActiveVPN := false
for _, activeConn := range activeConns {
connType, err := activeConn.GetPropertyType()
if err != nil {
continue
}
if connType == "vpn" || connType == "wireguard" {
state, _ := activeConn.GetPropertyState()
if state == 2 {
hasActiveVPN = true
break
}
}
}
if hasActiveVPN {
b.stateMutex.Lock()
b.state.NetworkStatus = StatusVPN
b.stateMutex.Unlock()
return nil
}
primaryConn, err := nm.GetPropertyPrimaryConnection()
if err != nil {
return err
}
if primaryConn == nil || primaryConn.GetPath() == "/" {
b.stateMutex.Lock()
b.state.NetworkStatus = StatusDisconnected
b.stateMutex.Unlock()
return nil
}
connType, err := primaryConn.GetPropertyType()
if err != nil {
return err
}
b.stateMutex.Lock()
switch connType {
case "802-3-ethernet":
b.state.NetworkStatus = StatusEthernet
case "802-11-wireless":
b.state.NetworkStatus = StatusWiFi
case "vpn", "wireguard":
b.state.NetworkStatus = StatusVPN
default:
b.state.NetworkStatus = StatusDisconnected
}
b.stateMutex.Unlock()
return nil
}
func (b *NetworkManagerBackend) updateEthernetState() error {
if b.ethernetDevice == nil {
return nil
}
dev := b.ethernetDevice.(gonetworkmanager.Device)
iface, err := dev.GetPropertyInterface()
if err != nil {
return err
}
state, err := dev.GetPropertyState()
if err != nil {
return err
}
connected := state == gonetworkmanager.NmDeviceStateActivated
var ip string
if connected {
ip = b.getDeviceIP(dev)
}
b.stateMutex.Lock()
b.state.EthernetDevice = iface
b.state.EthernetConnected = connected
b.state.EthernetIP = ip
b.stateMutex.Unlock()
return nil
}
func (b *NetworkManagerBackend) getDeviceStateReason(dev gonetworkmanager.Device) uint32 {
path := dev.GetPath()
obj := b.dbusConn.Object("org.freedesktop.NetworkManager", path)
variant, err := obj.GetProperty(dbusNMDeviceInterface + ".StateReason")
if err != nil {
return 0
}
if stateReasonStruct, ok := variant.Value().([]interface{}); ok && len(stateReasonStruct) >= 2 {
if reason, ok := stateReasonStruct[1].(uint32); ok {
return reason
}
}
return 0
}
func (b *NetworkManagerBackend) classifyNMStateReason(reason uint32) string {
switch reason {
case NmDeviceStateReasonWrongPassword,
NmDeviceStateReasonSupplicantTimeout,
NmDeviceStateReasonSupplicantFailed,
NmDeviceStateReasonSecretsRequired:
return errdefs.ErrBadCredentials
case NmDeviceStateReasonNoSecrets:
return errdefs.ErrUserCanceled
case NmDeviceStateReasonNoSsid:
return errdefs.ErrNoSuchSSID
case NmDeviceStateReasonDhcpClientFailed,
NmDeviceStateReasonIpConfigUnavailable:
return errdefs.ErrDhcpTimeout
case NmDeviceStateReasonSupplicantDisconnect,
NmDeviceStateReasonCarrier:
return errdefs.ErrAssocTimeout
default:
return errdefs.ErrConnectionFailed
}
}
func (b *NetworkManagerBackend) updateWiFiState() error {
if b.wifiDevice == nil {
return nil
}
dev := b.wifiDevice.(gonetworkmanager.Device)
iface, err := dev.GetPropertyInterface()
if err != nil {
return err
}
state, err := dev.GetPropertyState()
if err != nil {
return err
}
connected := state == gonetworkmanager.NmDeviceStateActivated
failed := state == gonetworkmanager.NmDeviceStateFailed
disconnected := state == gonetworkmanager.NmDeviceStateDisconnected
var ip, ssid, bssid string
var signal uint8
if connected {
if err := b.ensureWiFiDevice(); err == nil && b.wifiDev != nil {
w := b.wifiDev.(gonetworkmanager.DeviceWireless)
activeAP, err := w.GetPropertyActiveAccessPoint()
if err == nil && activeAP != nil && activeAP.GetPath() != "/" {
ssid, _ = activeAP.GetPropertySSID()
signal, _ = activeAP.GetPropertyStrength()
bssid, _ = activeAP.GetPropertyHWAddress()
}
}
ip = b.getDeviceIP(dev)
}
b.stateMutex.RLock()
wasConnecting := b.state.IsConnecting
connectingSSID := b.state.ConnectingSSID
b.stateMutex.RUnlock()
var reasonCode string
if wasConnecting && connectingSSID != "" && (failed || (disconnected && !connected)) {
reason := b.getDeviceStateReason(dev)
if reason == NmDeviceStateReasonNewActivation || reason == 0 {
return nil
}
log.Warnf("[updateWiFiState] Connection failed: SSID=%s, state=%d, reason=%d", connectingSSID, state, reason)
reasonCode = b.classifyNMStateReason(reason)
if reasonCode == errdefs.ErrConnectionFailed {
b.failedMutex.RLock()
if b.lastFailedSSID == connectingSSID {
elapsed := time.Now().Unix() - b.lastFailedTime
if elapsed < 5 {
reasonCode = errdefs.ErrBadCredentials
}
}
b.failedMutex.RUnlock()
}
}
b.stateMutex.Lock()
defer b.stateMutex.Unlock()
wasConnecting = b.state.IsConnecting
connectingSSID = b.state.ConnectingSSID
if wasConnecting && connectingSSID != "" {
if connected && ssid == connectingSSID {
log.Infof("[updateWiFiState] Connection successful: %s", ssid)
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = ""
} else if failed || (disconnected && !connected) {
log.Warnf("[updateWiFiState] Connection failed: SSID=%s, state=%d", connectingSSID, state)
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = reasonCode
// If user cancelled, delete the connection profile that was just created
if reasonCode == errdefs.ErrUserCanceled {
log.Infof("[updateWiFiState] User cancelled authentication, removing connection profile for %s", connectingSSID)
b.stateMutex.Unlock()
if err := b.ForgetWiFiNetwork(connectingSSID); err != nil {
log.Warnf("[updateWiFiState] Failed to remove cancelled connection: %v", err)
}
b.stateMutex.Lock()
}
b.failedMutex.Lock()
b.lastFailedSSID = connectingSSID
b.lastFailedTime = time.Now().Unix()
b.failedMutex.Unlock()
}
}
b.state.WiFiDevice = iface
b.state.WiFiConnected = connected
b.state.WiFiIP = ip
b.state.WiFiSSID = ssid
b.state.WiFiBSSID = bssid
b.state.WiFiSignal = signal
return nil
}
func (b *NetworkManagerBackend) getDeviceIP(dev gonetworkmanager.Device) string {
ip4Config, err := dev.GetPropertyIP4Config()
if err != nil || ip4Config == nil {
return ""
}
addresses, err := ip4Config.GetPropertyAddressData()
if err != nil || len(addresses) == 0 {
return ""
}
return addresses[0].Address
}

View File

@@ -0,0 +1,82 @@
package network
import (
"testing"
"github.com/AvengeMedia/danklinux/internal/errdefs"
mock_gonetworkmanager "github.com/AvengeMedia/danklinux/internal/mocks/github.com/Wifx/gonetworkmanager/v2"
"github.com/Wifx/gonetworkmanager/v2"
"github.com/stretchr/testify/assert"
)
func TestNetworkManagerBackend_UpdatePrimaryConnection(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
mockNM.EXPECT().GetPropertyActiveConnections().Return([]gonetworkmanager.ActiveConnection{}, nil)
mockNM.EXPECT().GetPropertyPrimaryConnection().Return(nil, nil)
err = backend.updatePrimaryConnection()
assert.NoError(t, err)
}
func TestNetworkManagerBackend_UpdateEthernetState_NoDevice(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.ethernetDevice = nil
err = backend.updateEthernetState()
assert.NoError(t, err)
}
func TestNetworkManagerBackend_UpdateWiFiState_NoDevice(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.wifiDevice = nil
err = backend.updateWiFiState()
assert.NoError(t, err)
}
func TestNetworkManagerBackend_ClassifyNMStateReason(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
testCases := []struct {
reason uint32
expected string
}{
{NmDeviceStateReasonWrongPassword, errdefs.ErrBadCredentials},
{NmDeviceStateReasonNoSecrets, errdefs.ErrUserCanceled},
{NmDeviceStateReasonSupplicantTimeout, errdefs.ErrBadCredentials},
{NmDeviceStateReasonDhcpClientFailed, errdefs.ErrDhcpTimeout},
{NmDeviceStateReasonNoSsid, errdefs.ErrNoSuchSSID},
{999, errdefs.ErrConnectionFailed},
}
for _, tc := range testCases {
result := backend.classifyNMStateReason(tc.reason)
assert.Equal(t, tc.expected, result)
}
}
func TestNetworkManagerBackend_GetDeviceIP_NoConfig(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
mockDevice := mock_gonetworkmanager.NewMockDevice(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
mockDevice.EXPECT().GetPropertyIP4Config().Return(nil, nil)
ip := backend.getDeviceIP(mockDevice)
assert.Empty(t, ip)
}

View File

@@ -0,0 +1,154 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNetworkManagerBackend_New(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
assert.NotNil(t, backend)
assert.Equal(t, "networkmanager", backend.state.Backend)
assert.NotNil(t, backend.stopChan)
assert.NotNil(t, backend.state)
}
func TestNetworkManagerBackend_GetCurrentState(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.state.NetworkStatus = StatusWiFi
backend.state.WiFiConnected = true
backend.state.WiFiSSID = "TestNetwork"
backend.state.WiFiIP = "192.168.1.100"
backend.state.WiFiNetworks = []WiFiNetwork{
{SSID: "TestNetwork", Signal: 80, Connected: true},
}
backend.state.WiredConnections = []WiredConnection{
{ID: "Wired connection 1", UUID: "test-uuid"},
}
state, err := backend.GetCurrentState()
assert.NoError(t, err)
assert.NotNil(t, state)
assert.Equal(t, StatusWiFi, state.NetworkStatus)
assert.True(t, state.WiFiConnected)
assert.Equal(t, "TestNetwork", state.WiFiSSID)
assert.Len(t, state.WiFiNetworks, 1)
assert.Len(t, state.WiredConnections, 1)
assert.NotSame(t, &backend.state.WiFiNetworks, &state.WiFiNetworks)
assert.NotSame(t, &backend.state.WiredConnections, &state.WiredConnections)
}
func TestNetworkManagerBackend_SetPromptBroker_Nil(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
err = backend.SetPromptBroker(nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot be nil")
}
func TestNetworkManagerBackend_SubmitCredentials_NoBroker(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.promptBroker = nil
err = backend.SubmitCredentials("token", map[string]string{"password": "test"}, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not initialized")
}
func TestNetworkManagerBackend_CancelCredentials_NoBroker(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.promptBroker = nil
err = backend.CancelCredentials("token")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not initialized")
}
func TestNetworkManagerBackend_EnsureWiFiDevice_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
backend.wifiDev = nil
err = backend.ensureWiFiDevice()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestNetworkManagerBackend_EnsureWiFiDevice_AlreadySet(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDev = "dummy-device"
err = backend.ensureWiFiDevice()
assert.NoError(t, err)
}
func TestNetworkManagerBackend_StartSecretAgent_NoBroker(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.promptBroker = nil
err = backend.startSecretAgent()
assert.Error(t, err)
assert.Contains(t, err.Error(), "prompt broker not set")
}
func TestNetworkManagerBackend_Close(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
assert.NotPanics(t, func() {
backend.Close()
})
}
func TestNetworkManagerBackend_GetPromptBroker(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
broker := backend.GetPromptBroker()
assert.Nil(t, broker)
}
func TestNetworkManagerBackend_StopMonitoring_NoSignals(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
assert.NotPanics(t, func() {
backend.StopMonitoring()
})
}

View File

@@ -0,0 +1,527 @@
package network
import (
"fmt"
"sort"
"strings"
"time"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/Wifx/gonetworkmanager/v2"
)
func (b *NetworkManagerBackend) ListVPNProfiles() ([]VPNProfile, error) {
s := b.settings
if s == nil {
var err error
s, err = gonetworkmanager.NewSettings()
if err != nil {
return nil, fmt.Errorf("failed to get settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return nil, fmt.Errorf("failed to get connections: %w", err)
}
var profiles []VPNProfile
for _, conn := range connections {
settings, err := conn.GetSettings()
if err != nil {
continue
}
connMeta, ok := settings["connection"]
if !ok {
continue
}
connType, _ := connMeta["type"].(string)
if connType != "vpn" && connType != "wireguard" {
continue
}
connID, _ := connMeta["id"].(string)
connUUID, _ := connMeta["uuid"].(string)
profile := VPNProfile{
Name: connID,
UUID: connUUID,
Type: connType,
}
if connType == "vpn" {
if vpnSettings, ok := settings["vpn"]; ok {
if svcType, ok := vpnSettings["service-type"].(string); ok {
profile.ServiceType = svcType
}
}
}
profiles = append(profiles, profile)
}
sort.Slice(profiles, func(i, j int) bool {
return strings.ToLower(profiles[i].Name) < strings.ToLower(profiles[j].Name)
})
b.stateMutex.Lock()
b.state.VPNProfiles = profiles
b.stateMutex.Unlock()
return profiles, nil
}
func (b *NetworkManagerBackend) ListActiveVPN() ([]VPNActive, error) {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeConns, err := nm.GetPropertyActiveConnections()
if err != nil {
return nil, fmt.Errorf("failed to get active connections: %w", err)
}
var active []VPNActive
for _, activeConn := range activeConns {
connType, err := activeConn.GetPropertyType()
if err != nil {
continue
}
if connType != "vpn" && connType != "wireguard" {
continue
}
uuid, _ := activeConn.GetPropertyUUID()
id, _ := activeConn.GetPropertyID()
state, _ := activeConn.GetPropertyState()
var stateStr string
switch state {
case 0:
stateStr = "unknown"
case 1:
stateStr = "activating"
case 2:
stateStr = "activated"
case 3:
stateStr = "deactivating"
case 4:
stateStr = "deactivated"
}
vpnActive := VPNActive{
Name: id,
UUID: uuid,
State: stateStr,
Type: connType,
Plugin: "",
}
if connType == "vpn" {
conn, _ := activeConn.GetPropertyConnection()
if conn != nil {
connSettings, err := conn.GetSettings()
if err == nil {
if vpnSettings, ok := connSettings["vpn"]; ok {
if svcType, ok := vpnSettings["service-type"].(string); ok {
vpnActive.Plugin = svcType
}
}
}
}
}
active = append(active, vpnActive)
}
b.stateMutex.Lock()
b.state.VPNActive = active
b.stateMutex.Unlock()
return active, nil
}
func (b *NetworkManagerBackend) ConnectVPN(uuidOrName string, singleActive bool) error {
if singleActive {
active, err := b.ListActiveVPN()
if err == nil && len(active) > 0 {
alreadyConnected := false
for _, vpn := range active {
if vpn.UUID == uuidOrName || vpn.Name == uuidOrName {
alreadyConnected = true
break
}
}
if !alreadyConnected {
if err := b.DisconnectAllVPN(); err != nil {
log.Warnf("Failed to disconnect existing VPNs: %v", err)
}
time.Sleep(500 * time.Millisecond)
} else {
return nil
}
}
}
s := b.settings
if s == nil {
var err error
s, err = gonetworkmanager.NewSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return fmt.Errorf("failed to get connections: %w", err)
}
var targetConn gonetworkmanager.Connection
for _, conn := range connections {
settings, err := conn.GetSettings()
if err != nil {
continue
}
connMeta, ok := settings["connection"]
if !ok {
continue
}
connType, _ := connMeta["type"].(string)
if connType != "vpn" && connType != "wireguard" {
continue
}
connID, _ := connMeta["id"].(string)
connUUID, _ := connMeta["uuid"].(string)
if connUUID == uuidOrName || connID == uuidOrName {
targetConn = conn
break
}
}
if targetConn == nil {
return fmt.Errorf("VPN connection not found: %s", uuidOrName)
}
targetSettings, err := targetConn.GetSettings()
if err != nil {
return fmt.Errorf("failed to get connection settings: %w", err)
}
var targetUUID string
if connMeta, ok := targetSettings["connection"]; ok {
if uuid, ok := connMeta["uuid"].(string); ok {
targetUUID = uuid
}
}
b.stateMutex.Lock()
b.state.IsConnectingVPN = true
b.state.ConnectingVPNUUID = targetUUID
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeConn, err := nm.ActivateConnection(targetConn, nil, nil)
if err != nil {
b.stateMutex.Lock()
b.state.IsConnectingVPN = false
b.state.ConnectingVPNUUID = ""
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
return fmt.Errorf("failed to activate VPN: %w", err)
}
if activeConn != nil {
state, _ := activeConn.GetPropertyState()
if state == 2 {
b.stateMutex.Lock()
b.state.IsConnectingVPN = false
b.state.ConnectingVPNUUID = ""
b.stateMutex.Unlock()
b.ListActiveVPN()
if b.onStateChange != nil {
b.onStateChange()
}
}
}
return nil
}
func (b *NetworkManagerBackend) DisconnectVPN(uuidOrName string) error {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeConns, err := nm.GetPropertyActiveConnections()
if err != nil {
return fmt.Errorf("failed to get active connections: %w", err)
}
log.Debugf("[DisconnectVPN] Looking for VPN: %s", uuidOrName)
for _, activeConn := range activeConns {
connType, err := activeConn.GetPropertyType()
if err != nil {
continue
}
if connType != "vpn" && connType != "wireguard" {
continue
}
uuid, _ := activeConn.GetPropertyUUID()
id, _ := activeConn.GetPropertyID()
state, _ := activeConn.GetPropertyState()
log.Debugf("[DisconnectVPN] Found active VPN: uuid=%s id=%s state=%d", uuid, id, state)
if uuid == uuidOrName || id == uuidOrName {
log.Infof("[DisconnectVPN] Deactivating VPN: %s (state=%d)", id, state)
if err := nm.DeactivateConnection(activeConn); err != nil {
return fmt.Errorf("failed to deactivate VPN: %w", err)
}
b.ListActiveVPN()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
}
log.Warnf("[DisconnectVPN] VPN not found in active connections: %s", uuidOrName)
s := b.settings
if s == nil {
var err error
s, err = gonetworkmanager.NewSettings()
if err != nil {
return fmt.Errorf("VPN connection not active and cannot access settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return fmt.Errorf("VPN connection not active: %s", uuidOrName)
}
for _, conn := range connections {
settings, err := conn.GetSettings()
if err != nil {
continue
}
connMeta, ok := settings["connection"]
if !ok {
continue
}
connType, _ := connMeta["type"].(string)
if connType != "vpn" && connType != "wireguard" {
continue
}
connID, _ := connMeta["id"].(string)
connUUID, _ := connMeta["uuid"].(string)
if connUUID == uuidOrName || connID == uuidOrName {
log.Infof("[DisconnectVPN] VPN connection exists but not active: %s", connID)
return nil
}
}
return fmt.Errorf("VPN connection not found: %s", uuidOrName)
}
func (b *NetworkManagerBackend) DisconnectAllVPN() error {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeConns, err := nm.GetPropertyActiveConnections()
if err != nil {
return fmt.Errorf("failed to get active connections: %w", err)
}
var lastErr error
var disconnected bool
for _, activeConn := range activeConns {
connType, err := activeConn.GetPropertyType()
if err != nil {
continue
}
if connType != "vpn" && connType != "wireguard" {
continue
}
if err := nm.DeactivateConnection(activeConn); err != nil {
lastErr = err
log.Warnf("Failed to deactivate VPN connection: %v", err)
} else {
disconnected = true
}
}
if disconnected {
b.ListActiveVPN()
if b.onStateChange != nil {
b.onStateChange()
}
}
return lastErr
}
func (b *NetworkManagerBackend) ClearVPNCredentials(uuidOrName string) error {
s := b.settings
if s == nil {
var err error
s, err = gonetworkmanager.NewSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return fmt.Errorf("failed to get connections: %w", err)
}
for _, conn := range connections {
settings, err := conn.GetSettings()
if err != nil {
continue
}
connMeta, ok := settings["connection"]
if !ok {
continue
}
connType, _ := connMeta["type"].(string)
if connType != "vpn" && connType != "wireguard" {
continue
}
connID, _ := connMeta["id"].(string)
connUUID, _ := connMeta["uuid"].(string)
if connUUID == uuidOrName || connID == uuidOrName {
if connType == "vpn" {
if vpnSettings, ok := settings["vpn"]; ok {
delete(vpnSettings, "secrets")
if dataMap, ok := vpnSettings["data"].(map[string]string); ok {
dataMap["password-flags"] = "1"
vpnSettings["data"] = dataMap
}
vpnSettings["password-flags"] = uint32(1)
}
settings["vpn-secrets"] = make(map[string]interface{})
}
if err := conn.Update(settings); err != nil {
return fmt.Errorf("failed to update connection: %w", err)
}
if err := conn.ClearSecrets(); err != nil {
log.Warnf("ClearSecrets call failed (may not be critical): %v", err)
}
log.Infof("Cleared credentials for VPN: %s", connID)
return nil
}
}
return fmt.Errorf("VPN connection not found: %s", uuidOrName)
}
func (b *NetworkManagerBackend) updateVPNConnectionState() {
b.stateMutex.RLock()
isConnectingVPN := b.state.IsConnectingVPN
connectingVPNUUID := b.state.ConnectingVPNUUID
b.stateMutex.RUnlock()
if !isConnectingVPN || connectingVPNUUID == "" {
return
}
nm := b.nmConn.(gonetworkmanager.NetworkManager)
activeConns, err := nm.GetPropertyActiveConnections()
if err != nil {
return
}
foundConnection := false
for _, activeConn := range activeConns {
connType, err := activeConn.GetPropertyType()
if err != nil {
continue
}
if connType != "vpn" && connType != "wireguard" {
continue
}
uuid, err := activeConn.GetPropertyUUID()
if err != nil {
continue
}
state, _ := activeConn.GetPropertyState()
stateReason, _ := activeConn.GetPropertyStateFlags()
if uuid == connectingVPNUUID {
foundConnection = true
switch state {
case 2:
log.Infof("[updateVPNConnectionState] VPN connection successful: %s", uuid)
b.stateMutex.Lock()
b.state.IsConnectingVPN = false
b.state.ConnectingVPNUUID = ""
b.state.LastError = ""
b.stateMutex.Unlock()
return
case 4:
log.Warnf("[updateVPNConnectionState] VPN connection failed/deactivated: %s (state=%d, flags=%d)", uuid, state, stateReason)
b.stateMutex.Lock()
b.state.IsConnectingVPN = false
b.state.ConnectingVPNUUID = ""
b.state.LastError = "VPN connection failed"
b.stateMutex.Unlock()
return
}
}
}
if !foundConnection {
log.Warnf("[updateVPNConnectionState] VPN connection no longer exists: %s", connectingVPNUUID)
b.stateMutex.Lock()
b.state.IsConnectingVPN = false
b.state.ConnectingVPNUUID = ""
b.state.LastError = "VPN connection failed"
b.stateMutex.Unlock()
}
}

View File

@@ -0,0 +1,138 @@
package network
import (
"testing"
mock_gonetworkmanager "github.com/AvengeMedia/danklinux/internal/mocks/github.com/Wifx/gonetworkmanager/v2"
"github.com/Wifx/gonetworkmanager/v2"
"github.com/stretchr/testify/assert"
)
func TestNetworkManagerBackend_ListVPNProfiles(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
mockSettings := mock_gonetworkmanager.NewMockSettings(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.settings = mockSettings
mockSettings.EXPECT().ListConnections().Return([]gonetworkmanager.Connection{}, nil)
profiles, err := backend.ListVPNProfiles()
assert.NoError(t, err)
assert.Empty(t, profiles)
}
func TestNetworkManagerBackend_ListActiveVPN(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
mockNM.EXPECT().GetPropertyActiveConnections().Return([]gonetworkmanager.ActiveConnection{}, nil)
active, err := backend.ListActiveVPN()
assert.NoError(t, err)
assert.Empty(t, active)
}
func TestNetworkManagerBackend_ConnectVPN_NotFound(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
mockSettings := mock_gonetworkmanager.NewMockSettings(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.settings = mockSettings
mockSettings.EXPECT().ListConnections().Return([]gonetworkmanager.Connection{}, nil)
err = backend.ConnectVPN("non-existent-vpn-12345", false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestNetworkManagerBackend_ConnectVPN_SingleActive_NoActiveVPN(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
mockSettings := mock_gonetworkmanager.NewMockSettings(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.settings = mockSettings
mockSettings.EXPECT().ListConnections().Return([]gonetworkmanager.Connection{}, nil)
mockNM.EXPECT().GetPropertyActiveConnections().Return([]gonetworkmanager.ActiveConnection{}, nil)
err = backend.ConnectVPN("non-existent-vpn-12345", true)
assert.Error(t, err)
}
func TestNetworkManagerBackend_DisconnectVPN_NotActive(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
mockNM.EXPECT().GetPropertyActiveConnections().Return([]gonetworkmanager.ActiveConnection{}, nil)
err = backend.DisconnectVPN("non-existent-vpn-12345")
assert.Error(t, err)
}
func TestNetworkManagerBackend_DisconnectAllVPN(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
mockNM.EXPECT().GetPropertyActiveConnections().Return([]gonetworkmanager.ActiveConnection{}, nil)
err = backend.DisconnectAllVPN()
assert.NoError(t, err)
}
func TestNetworkManagerBackend_ClearVPNCredentials_NotFound(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
mockSettings := mock_gonetworkmanager.NewMockSettings(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.settings = mockSettings
mockSettings.EXPECT().ListConnections().Return([]gonetworkmanager.Connection{}, nil)
err = backend.ClearVPNCredentials("non-existent-vpn-12345")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestNetworkManagerBackend_UpdateVPNConnectionState_NotConnecting(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.stateMutex.Lock()
backend.state.IsConnectingVPN = false
backend.state.ConnectingVPNUUID = ""
backend.stateMutex.Unlock()
assert.NotPanics(t, func() {
backend.updateVPNConnectionState()
})
}
func TestNetworkManagerBackend_UpdateVPNConnectionState_EmptyUUID(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
backend.stateMutex.Lock()
backend.state.IsConnectingVPN = true
backend.state.ConnectingVPNUUID = ""
backend.stateMutex.Unlock()
assert.NotPanics(t, func() {
backend.updateVPNConnectionState()
})
}

View File

@@ -0,0 +1,718 @@
package network
import (
"bytes"
"fmt"
"sort"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/Wifx/gonetworkmanager/v2"
)
func (b *NetworkManagerBackend) GetWiFiEnabled() (bool, error) {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
return nm.GetPropertyWirelessEnabled()
}
func (b *NetworkManagerBackend) SetWiFiEnabled(enabled bool) error {
nm := b.nmConn.(gonetworkmanager.NetworkManager)
err := nm.SetPropertyWirelessEnabled(enabled)
if err != nil {
return fmt.Errorf("failed to set WiFi enabled: %w", err)
}
b.stateMutex.Lock()
b.state.WiFiEnabled = enabled
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *NetworkManagerBackend) ScanWiFi() error {
if b.wifiDevice == nil {
return fmt.Errorf("no WiFi device available")
}
b.stateMutex.RLock()
enabled := b.state.WiFiEnabled
b.stateMutex.RUnlock()
if !enabled {
return fmt.Errorf("WiFi is disabled")
}
if err := b.ensureWiFiDevice(); err != nil {
return err
}
w := b.wifiDev.(gonetworkmanager.DeviceWireless)
err := w.RequestScan()
if err != nil {
return fmt.Errorf("scan request failed: %w", err)
}
_, err = b.updateWiFiNetworks()
return err
}
func (b *NetworkManagerBackend) GetWiFiNetworkDetails(ssid string) (*NetworkInfoResponse, error) {
if b.wifiDevice == nil {
return nil, fmt.Errorf("no WiFi device available")
}
if err := b.ensureWiFiDevice(); err != nil {
return nil, err
}
wifiDev := b.wifiDev
w := wifiDev.(gonetworkmanager.DeviceWireless)
apPaths, err := w.GetAccessPoints()
if err != nil {
return nil, fmt.Errorf("failed to get access points: %w", err)
}
s := b.settings
if s == nil {
s, err = gonetworkmanager.NewSettings()
if err != nil {
return nil, fmt.Errorf("failed to get settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return nil, fmt.Errorf("failed to get connections: %w", err)
}
savedSSIDs := make(map[string]bool)
autoconnectMap := make(map[string]bool)
for _, conn := range connections {
connSettings, err := conn.GetSettings()
if err != nil {
continue
}
if connMeta, ok := connSettings["connection"]; ok {
if connType, ok := connMeta["type"].(string); ok && connType == "802-11-wireless" {
if wifiSettings, ok := connSettings["802-11-wireless"]; ok {
if ssidBytes, ok := wifiSettings["ssid"].([]byte); ok {
savedSSID := string(ssidBytes)
savedSSIDs[savedSSID] = true
autoconnect := true
if ac, ok := connMeta["autoconnect"].(bool); ok {
autoconnect = ac
}
autoconnectMap[savedSSID] = autoconnect
}
}
}
}
}
b.stateMutex.RLock()
currentSSID := b.state.WiFiSSID
currentBSSID := b.state.WiFiBSSID
b.stateMutex.RUnlock()
var bands []WiFiNetwork
for _, ap := range apPaths {
apSSID, err := ap.GetPropertySSID()
if err != nil || apSSID != ssid {
continue
}
strength, _ := ap.GetPropertyStrength()
flags, _ := ap.GetPropertyFlags()
wpaFlags, _ := ap.GetPropertyWPAFlags()
rsnFlags, _ := ap.GetPropertyRSNFlags()
freq, _ := ap.GetPropertyFrequency()
maxBitrate, _ := ap.GetPropertyMaxBitrate()
bssid, _ := ap.GetPropertyHWAddress()
mode, _ := ap.GetPropertyMode()
secured := flags != uint32(gonetworkmanager.Nm80211APFlagsNone) ||
wpaFlags != uint32(gonetworkmanager.Nm80211APSecNone) ||
rsnFlags != uint32(gonetworkmanager.Nm80211APSecNone)
enterprise := (rsnFlags&uint32(gonetworkmanager.Nm80211APSecKeyMgmt8021X) != 0) ||
(wpaFlags&uint32(gonetworkmanager.Nm80211APSecKeyMgmt8021X) != 0)
var modeStr string
switch mode {
case gonetworkmanager.Nm80211ModeAdhoc:
modeStr = "adhoc"
case gonetworkmanager.Nm80211ModeInfra:
modeStr = "infrastructure"
case gonetworkmanager.Nm80211ModeAp:
modeStr = "ap"
default:
modeStr = "unknown"
}
channel := frequencyToChannel(freq)
network := WiFiNetwork{
SSID: ssid,
BSSID: bssid,
Signal: strength,
Secured: secured,
Enterprise: enterprise,
Connected: ssid == currentSSID && bssid == currentBSSID,
Saved: savedSSIDs[ssid],
Autoconnect: autoconnectMap[ssid],
Frequency: freq,
Mode: modeStr,
Rate: maxBitrate / 1000,
Channel: channel,
}
bands = append(bands, network)
}
if len(bands) == 0 {
return nil, fmt.Errorf("network not found: %s", ssid)
}
sort.Slice(bands, func(i, j int) bool {
if bands[i].Connected && !bands[j].Connected {
return true
}
if !bands[i].Connected && bands[j].Connected {
return false
}
return bands[i].Signal > bands[j].Signal
})
return &NetworkInfoResponse{
SSID: ssid,
Bands: bands,
}, nil
}
func (b *NetworkManagerBackend) ConnectWiFi(req ConnectionRequest) error {
if b.wifiDevice == nil {
return fmt.Errorf("no WiFi device available")
}
b.stateMutex.RLock()
alreadyConnected := b.state.WiFiConnected && b.state.WiFiSSID == req.SSID
b.stateMutex.RUnlock()
if alreadyConnected && !req.Interactive {
return nil
}
b.stateMutex.Lock()
b.state.IsConnecting = true
b.state.ConnectingSSID = req.SSID
b.state.LastError = ""
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
nm := b.nmConn.(gonetworkmanager.NetworkManager)
existingConn, err := b.findConnection(req.SSID)
if err == nil && existingConn != nil {
dev := b.wifiDevice.(gonetworkmanager.Device)
_, err := nm.ActivateConnection(existingConn, dev, nil)
if err != nil {
log.Warnf("[ConnectWiFi] Failed to activate existing connection: %v", err)
b.stateMutex.Lock()
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = fmt.Sprintf("failed to activate connection: %v", err)
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
return fmt.Errorf("failed to activate connection: %w", err)
}
return nil
}
if err := b.createAndConnectWiFi(req); err != nil {
log.Warnf("[ConnectWiFi] Failed to create and connect: %v", err)
b.stateMutex.Lock()
b.state.IsConnecting = false
b.state.ConnectingSSID = ""
b.state.LastError = err.Error()
b.stateMutex.Unlock()
if b.onStateChange != nil {
b.onStateChange()
}
return err
}
return nil
}
func (b *NetworkManagerBackend) DisconnectWiFi() error {
if b.wifiDevice == nil {
return fmt.Errorf("no WiFi device available")
}
dev := b.wifiDevice.(gonetworkmanager.Device)
err := dev.Disconnect()
if err != nil {
return fmt.Errorf("failed to disconnect: %w", err)
}
b.updateWiFiState()
b.updatePrimaryConnection()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *NetworkManagerBackend) ForgetWiFiNetwork(ssid string) error {
conn, err := b.findConnection(ssid)
if err != nil {
return fmt.Errorf("connection not found: %w", err)
}
b.stateMutex.RLock()
currentSSID := b.state.WiFiSSID
isConnected := b.state.WiFiConnected
b.stateMutex.RUnlock()
err = conn.Delete()
if err != nil {
return fmt.Errorf("failed to delete connection: %w", err)
}
if isConnected && currentSSID == ssid {
b.stateMutex.Lock()
b.state.WiFiConnected = false
b.state.WiFiSSID = ""
b.state.WiFiBSSID = ""
b.state.WiFiSignal = 0
b.state.WiFiIP = ""
b.state.NetworkStatus = StatusDisconnected
b.stateMutex.Unlock()
}
b.updateWiFiNetworks()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}
func (b *NetworkManagerBackend) IsConnectingTo(ssid string) bool {
b.stateMutex.RLock()
defer b.stateMutex.RUnlock()
return b.state.IsConnecting && b.state.ConnectingSSID == ssid
}
func (b *NetworkManagerBackend) updateWiFiNetworks() ([]WiFiNetwork, error) {
if b.wifiDevice == nil {
return nil, fmt.Errorf("no WiFi device available")
}
if err := b.ensureWiFiDevice(); err != nil {
return nil, err
}
wifiDev := b.wifiDev
w := wifiDev.(gonetworkmanager.DeviceWireless)
apPaths, err := w.GetAccessPoints()
if err != nil {
return nil, fmt.Errorf("failed to get access points: %w", err)
}
s := b.settings
if s == nil {
s, err = gonetworkmanager.NewSettings()
if err != nil {
return nil, fmt.Errorf("failed to get settings: %w", err)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
connections, err := settingsMgr.ListConnections()
if err != nil {
return nil, fmt.Errorf("failed to get connections: %w", err)
}
savedSSIDs := make(map[string]bool)
autoconnectMap := make(map[string]bool)
for _, conn := range connections {
connSettings, err := conn.GetSettings()
if err != nil {
continue
}
if connMeta, ok := connSettings["connection"]; ok {
if connType, ok := connMeta["type"].(string); ok && connType == "802-11-wireless" {
if wifiSettings, ok := connSettings["802-11-wireless"]; ok {
if ssidBytes, ok := wifiSettings["ssid"].([]byte); ok {
ssid := string(ssidBytes)
savedSSIDs[ssid] = true
autoconnect := true
if ac, ok := connMeta["autoconnect"].(bool); ok {
autoconnect = ac
}
autoconnectMap[ssid] = autoconnect
}
}
}
}
}
b.stateMutex.RLock()
currentSSID := b.state.WiFiSSID
b.stateMutex.RUnlock()
seenSSIDs := make(map[string]*WiFiNetwork)
networks := []WiFiNetwork{}
for _, ap := range apPaths {
ssid, err := ap.GetPropertySSID()
if err != nil || ssid == "" {
continue
}
if existing, exists := seenSSIDs[ssid]; exists {
strength, _ := ap.GetPropertyStrength()
if strength > existing.Signal {
existing.Signal = strength
freq, _ := ap.GetPropertyFrequency()
existing.Frequency = freq
bssid, _ := ap.GetPropertyHWAddress()
existing.BSSID = bssid
}
continue
}
strength, _ := ap.GetPropertyStrength()
flags, _ := ap.GetPropertyFlags()
wpaFlags, _ := ap.GetPropertyWPAFlags()
rsnFlags, _ := ap.GetPropertyRSNFlags()
freq, _ := ap.GetPropertyFrequency()
maxBitrate, _ := ap.GetPropertyMaxBitrate()
bssid, _ := ap.GetPropertyHWAddress()
mode, _ := ap.GetPropertyMode()
secured := flags != uint32(gonetworkmanager.Nm80211APFlagsNone) ||
wpaFlags != uint32(gonetworkmanager.Nm80211APSecNone) ||
rsnFlags != uint32(gonetworkmanager.Nm80211APSecNone)
enterprise := (rsnFlags&uint32(gonetworkmanager.Nm80211APSecKeyMgmt8021X) != 0) ||
(wpaFlags&uint32(gonetworkmanager.Nm80211APSecKeyMgmt8021X) != 0)
var modeStr string
switch mode {
case gonetworkmanager.Nm80211ModeAdhoc:
modeStr = "adhoc"
case gonetworkmanager.Nm80211ModeInfra:
modeStr = "infrastructure"
case gonetworkmanager.Nm80211ModeAp:
modeStr = "ap"
default:
modeStr = "unknown"
}
channel := frequencyToChannel(freq)
network := WiFiNetwork{
SSID: ssid,
BSSID: bssid,
Signal: strength,
Secured: secured,
Enterprise: enterprise,
Connected: ssid == currentSSID,
Saved: savedSSIDs[ssid],
Autoconnect: autoconnectMap[ssid],
Frequency: freq,
Mode: modeStr,
Rate: maxBitrate / 1000,
Channel: channel,
}
seenSSIDs[ssid] = &network
networks = append(networks, network)
}
sortWiFiNetworks(networks)
b.stateMutex.Lock()
b.state.WiFiNetworks = networks
b.stateMutex.Unlock()
return networks, nil
}
func (b *NetworkManagerBackend) findConnection(ssid string) (gonetworkmanager.Connection, error) {
s := b.settings
if s == nil {
var err error
s, err = gonetworkmanager.NewSettings()
if err != nil {
return nil, err
}
b.settings = s
}
settings := s.(gonetworkmanager.Settings)
connections, err := settings.ListConnections()
if err != nil {
return nil, err
}
ssidBytes := []byte(ssid)
for _, conn := range connections {
connSettings, err := conn.GetSettings()
if err != nil {
continue
}
if connMeta, ok := connSettings["connection"]; ok {
if connType, ok := connMeta["type"].(string); ok && connType == "802-11-wireless" {
if wifiSettings, ok := connSettings["802-11-wireless"]; ok {
if candidateSSID, ok := wifiSettings["ssid"].([]byte); ok {
if bytes.Equal(candidateSSID, ssidBytes) {
return conn, nil
}
}
}
}
}
}
return nil, fmt.Errorf("connection not found")
}
func (b *NetworkManagerBackend) createAndConnectWiFi(req ConnectionRequest) error {
if b.wifiDevice == nil {
return fmt.Errorf("no WiFi device available")
}
nm := b.nmConn.(gonetworkmanager.NetworkManager)
dev := b.wifiDevice.(gonetworkmanager.Device)
if err := b.ensureWiFiDevice(); err != nil {
return err
}
wifiDev := b.wifiDev
w := wifiDev.(gonetworkmanager.DeviceWireless)
apPaths, err := w.GetAccessPoints()
if err != nil {
return fmt.Errorf("failed to get access points: %w", err)
}
var targetAP gonetworkmanager.AccessPoint
for _, ap := range apPaths {
ssid, err := ap.GetPropertySSID()
if err != nil || ssid != req.SSID {
continue
}
targetAP = ap
break
}
if targetAP == nil {
return fmt.Errorf("access point not found: %s", req.SSID)
}
flags, _ := targetAP.GetPropertyFlags()
wpaFlags, _ := targetAP.GetPropertyWPAFlags()
rsnFlags, _ := targetAP.GetPropertyRSNFlags()
const KeyMgmt8021x = uint32(512)
const KeyMgmtPsk = uint32(256)
const KeyMgmtSae = uint32(1024)
isEnterprise := (wpaFlags&KeyMgmt8021x) != 0 || (rsnFlags&KeyMgmt8021x) != 0
isPsk := (wpaFlags&KeyMgmtPsk) != 0 || (rsnFlags&KeyMgmtPsk) != 0
isSae := (wpaFlags&KeyMgmtSae) != 0 || (rsnFlags&KeyMgmtSae) != 0
secured := flags != uint32(gonetworkmanager.Nm80211APFlagsNone) ||
wpaFlags != uint32(gonetworkmanager.Nm80211APSecNone) ||
rsnFlags != uint32(gonetworkmanager.Nm80211APSecNone)
if isEnterprise {
log.Infof("[createAndConnectWiFi] Enterprise network detected (802.1x) - SSID: %s, interactive: %v",
req.SSID, req.Interactive)
}
settings := make(map[string]map[string]interface{})
settings["connection"] = map[string]interface{}{
"id": req.SSID,
"type": "802-11-wireless",
"autoconnect": true,
}
settings["ipv4"] = map[string]interface{}{"method": "auto"}
settings["ipv6"] = map[string]interface{}{"method": "auto"}
if secured {
settings["802-11-wireless"] = map[string]interface{}{
"ssid": []byte(req.SSID),
"mode": "infrastructure",
"security": "802-11-wireless-security",
}
switch {
case isEnterprise || req.Username != "":
settings["802-11-wireless-security"] = map[string]interface{}{
"key-mgmt": "wpa-eap",
}
x := map[string]interface{}{
"eap": []string{"peap"},
"phase2-auth": "mschapv2",
"system-ca-certs": false,
"password-flags": uint32(0),
}
if req.Username != "" {
x["identity"] = req.Username
}
if req.Password != "" {
x["password"] = req.Password
}
if req.AnonymousIdentity != "" {
x["anonymous-identity"] = req.AnonymousIdentity
}
if req.DomainSuffixMatch != "" {
x["domain-suffix-match"] = req.DomainSuffixMatch
}
settings["802-1x"] = x
log.Infof("[createAndConnectWiFi] WPA-EAP settings: eap=peap, phase2-auth=mschapv2, identity=%s, interactive=%v, system-ca-certs=%v, domain-suffix-match=%q",
req.Username, req.Interactive, x["system-ca-certs"], req.DomainSuffixMatch)
case isPsk:
sec := map[string]interface{}{
"key-mgmt": "wpa-psk",
"psk-flags": uint32(0),
}
if !req.Interactive {
sec["psk"] = req.Password
}
settings["802-11-wireless-security"] = sec
case isSae:
sec := map[string]interface{}{
"key-mgmt": "sae",
"pmf": int32(3),
"psk-flags": uint32(0),
}
if !req.Interactive {
sec["psk"] = req.Password
}
settings["802-11-wireless-security"] = sec
default:
return fmt.Errorf("secured network but not SAE/PSK/802.1X (rsn=0x%x wpa=0x%x)", rsnFlags, wpaFlags)
}
} else {
settings["802-11-wireless"] = map[string]interface{}{
"ssid": []byte(req.SSID),
"mode": "infrastructure",
}
}
if req.Interactive {
s := b.settings
if s == nil {
var settingsErr error
s, settingsErr = gonetworkmanager.NewSettings()
if settingsErr != nil {
return fmt.Errorf("failed to get settings manager: %w", settingsErr)
}
b.settings = s
}
settingsMgr := s.(gonetworkmanager.Settings)
conn, err := settingsMgr.AddConnection(settings)
if err != nil {
return fmt.Errorf("failed to add connection: %w", err)
}
if isEnterprise {
log.Infof("[createAndConnectWiFi] Enterprise connection added, activating (secret agent will be called)")
}
_, err = nm.ActivateWirelessConnection(conn, dev, targetAP)
if err != nil {
return fmt.Errorf("failed to activate connection: %w", err)
}
log.Infof("[createAndConnectWiFi] Connection activation initiated, waiting for NetworkManager state changes...")
} else {
_, err = nm.AddAndActivateWirelessConnection(settings, dev, targetAP)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
log.Infof("[createAndConnectWiFi] Connection activation initiated, waiting for NetworkManager state changes...")
}
return nil
}
func (b *NetworkManagerBackend) SetWiFiAutoconnect(ssid string, autoconnect bool) error {
conn, err := b.findConnection(ssid)
if err != nil {
return fmt.Errorf("connection not found: %w", err)
}
settings, err := conn.GetSettings()
if err != nil {
return fmt.Errorf("failed to get connection settings: %w", err)
}
if connMeta, ok := settings["connection"]; ok {
connMeta["autoconnect"] = autoconnect
} else {
return fmt.Errorf("connection metadata not found")
}
if ipv4, ok := settings["ipv4"]; ok {
delete(ipv4, "addresses")
delete(ipv4, "routes")
delete(ipv4, "dns")
}
if ipv6, ok := settings["ipv6"]; ok {
delete(ipv6, "addresses")
delete(ipv6, "routes")
delete(ipv6, "dns")
}
err = conn.Update(settings)
if err != nil {
return fmt.Errorf("failed to update connection: %w", err)
}
b.updateWiFiNetworks()
if b.onStateChange != nil {
b.onStateChange()
}
return nil
}

View File

@@ -0,0 +1,198 @@
package network
import (
"testing"
mock_gonetworkmanager "github.com/AvengeMedia/danklinux/internal/mocks/github.com/Wifx/gonetworkmanager/v2"
"github.com/stretchr/testify/assert"
)
func TestNetworkManagerBackend_GetWiFiEnabled(t *testing.T) {
mockNM := mock_gonetworkmanager.NewMockNetworkManager(t)
backend, err := NewNetworkManagerBackend(mockNM)
assert.NoError(t, err)
mockNM.EXPECT().GetPropertyWirelessEnabled().Return(true, nil)
enabled, err := backend.GetWiFiEnabled()
assert.NoError(t, err)
assert.True(t, enabled)
}
func TestNetworkManagerBackend_SetWiFiEnabled(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
originalState, err := backend.GetWiFiEnabled()
if err != nil {
t.Skipf("Cannot get WiFi state: %v", err)
}
defer func() {
backend.SetWiFiEnabled(originalState)
}()
err = backend.SetWiFiEnabled(!originalState)
assert.NoError(t, err)
backend.stateMutex.RLock()
assert.Equal(t, !originalState, backend.state.WiFiEnabled)
backend.stateMutex.RUnlock()
}
func TestNetworkManagerBackend_ScanWiFi_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
err = backend.ScanWiFi()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestNetworkManagerBackend_ScanWiFi_Disabled(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
if backend.wifiDevice == nil {
t.Skip("No WiFi device available")
}
backend.stateMutex.Lock()
backend.state.WiFiEnabled = false
backend.stateMutex.Unlock()
err = backend.ScanWiFi()
assert.Error(t, err)
assert.Contains(t, err.Error(), "WiFi is disabled")
}
func TestNetworkManagerBackend_GetWiFiNetworkDetails_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
_, err = backend.GetWiFiNetworkDetails("TestNetwork")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestNetworkManagerBackend_ConnectWiFi_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
req := ConnectionRequest{SSID: "TestNetwork", Password: "password"}
err = backend.ConnectWiFi(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestNetworkManagerBackend_ConnectWiFi_AlreadyConnected(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
if backend.wifiDevice == nil {
t.Skip("No WiFi device available")
}
backend.stateMutex.Lock()
backend.state.WiFiConnected = true
backend.state.WiFiSSID = "TestNetwork"
backend.stateMutex.Unlock()
req := ConnectionRequest{SSID: "TestNetwork", Password: "password"}
err = backend.ConnectWiFi(req)
assert.NoError(t, err)
}
func TestNetworkManagerBackend_DisconnectWiFi_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
err = backend.DisconnectWiFi()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestNetworkManagerBackend_IsConnectingTo(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.stateMutex.Lock()
backend.state.IsConnecting = true
backend.state.ConnectingSSID = "TestNetwork"
backend.stateMutex.Unlock()
assert.True(t, backend.IsConnectingTo("TestNetwork"))
assert.False(t, backend.IsConnectingTo("OtherNetwork"))
}
func TestNetworkManagerBackend_IsConnectingTo_NotConnecting(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.stateMutex.Lock()
backend.state.IsConnecting = false
backend.state.ConnectingSSID = ""
backend.stateMutex.Unlock()
assert.False(t, backend.IsConnectingTo("TestNetwork"))
}
func TestNetworkManagerBackend_UpdateWiFiNetworks_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
_, err = backend.updateWiFiNetworks()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestNetworkManagerBackend_FindConnection_NoSettings(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.settings = nil
_, err = backend.findConnection("NonExistentNetwork")
assert.Error(t, err)
}
func TestNetworkManagerBackend_CreateAndConnectWiFi_NoDevice(t *testing.T) {
backend, err := NewNetworkManagerBackend()
if err != nil {
t.Skipf("NetworkManager not available: %v", err)
}
backend.wifiDevice = nil
backend.wifiDev = nil
req := ConnectionRequest{SSID: "TestNetwork", Password: "password"}
err = backend.createAndConnectWiFi(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}

View File

@@ -0,0 +1,22 @@
package network
import (
"context"
"crypto/rand"
"encoding/hex"
)
type PromptBroker interface {
Ask(ctx context.Context, req PromptRequest) (token string, err error)
Wait(ctx context.Context, token string) (PromptReply, error)
Resolve(token string, reply PromptReply) error
Cancel(path string, setting string) error
}
func generateToken() (string, error) {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@@ -0,0 +1,109 @@
package network_test
import (
"errors"
"testing"
mocks_network "github.com/AvengeMedia/danklinux/internal/mocks/network"
"github.com/AvengeMedia/danklinux/internal/server/network"
"github.com/stretchr/testify/assert"
)
func TestConnectionRequest_Validation(t *testing.T) {
t.Run("basic WiFi connection", func(t *testing.T) {
req := network.ConnectionRequest{
SSID: "TestNetwork",
Password: "testpass123",
}
assert.NotEmpty(t, req.SSID)
assert.NotEmpty(t, req.Password)
assert.Empty(t, req.Username)
})
t.Run("enterprise WiFi connection", func(t *testing.T) {
req := network.ConnectionRequest{
SSID: "EnterpriseNetwork",
Password: "testpass123",
Username: "testuser",
}
assert.NotEmpty(t, req.SSID)
assert.NotEmpty(t, req.Password)
assert.NotEmpty(t, req.Username)
})
t.Run("open WiFi connection", func(t *testing.T) {
req := network.ConnectionRequest{
SSID: "OpenNetwork",
}
assert.NotEmpty(t, req.SSID)
assert.Empty(t, req.Password)
assert.Empty(t, req.Username)
})
}
func TestManager_ConnectWiFi_NoDevice(t *testing.T) {
backend := mocks_network.NewMockBackend(t)
req := network.ConnectionRequest{
SSID: "TestNetwork",
Password: "testpass123",
}
backend.EXPECT().ConnectWiFi(req).Return(errors.New("no WiFi device available"))
manager := network.NewTestManager(backend, &network.NetworkState{})
err := manager.ConnectWiFi(req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestManager_DisconnectWiFi_NoDevice(t *testing.T) {
backend := mocks_network.NewMockBackend(t)
backend.EXPECT().DisconnectWiFi().Return(errors.New("no WiFi device available"))
manager := network.NewTestManager(backend, &network.NetworkState{})
err := manager.DisconnectWiFi()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no WiFi device available")
}
func TestManager_ForgetWiFiNetwork_NotFound(t *testing.T) {
backend := mocks_network.NewMockBackend(t)
backend.EXPECT().ForgetWiFiNetwork("NonExistentNetwork").Return(errors.New("connection not found"))
manager := network.NewTestManager(backend, &network.NetworkState{})
err := manager.ForgetWiFiNetwork("NonExistentNetwork")
assert.Error(t, err)
assert.Contains(t, err.Error(), "connection not found")
}
func TestManager_ConnectEthernet_NoDevice(t *testing.T) {
backend := mocks_network.NewMockBackend(t)
backend.EXPECT().ConnectEthernet().Return(errors.New("no ethernet device available"))
manager := network.NewTestManager(backend, &network.NetworkState{})
err := manager.ConnectEthernet()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
func TestManager_DisconnectEthernet_NoDevice(t *testing.T) {
backend := mocks_network.NewMockBackend(t)
backend.EXPECT().DisconnectEthernet().Return(errors.New("no ethernet device available"))
manager := network.NewTestManager(backend, &network.NetworkState{})
err := manager.DisconnectEthernet()
assert.Error(t, err)
assert.Contains(t, err.Error(), "no ethernet device available")
}
// Note: More comprehensive tests for connection operations would require
// mocking the NetworkManager D-Bus interfaces, which is beyond the scope
// of these unit tests. The tests above cover the basic error cases and
// validation logic. Integration tests would be needed for full coverage.

View File

@@ -0,0 +1,89 @@
package network
import (
"fmt"
"github.com/godbus/dbus/v5"
)
type BackendType int
const (
BackendNone BackendType = iota
BackendNetworkManager
BackendIwd
BackendConnMan
BackendNetworkd
)
func nameHasOwner(bus *dbus.Conn, name string) (bool, error) {
obj := bus.Object("org.freedesktop.DBus", "/org/freedesktop/DBus")
var owned bool
if err := obj.Call("org.freedesktop.DBus.NameHasOwner", 0, name).Store(&owned); err != nil {
return false, err
}
return owned, nil
}
type DetectResult struct {
Backend BackendType
HasNM bool
HasIwd bool
HasConnMan bool
HasWpaSupp bool
HasNetworkd bool
ChosenReason string
}
func DetectNetworkStack() (*DetectResult, error) {
bus, err := dbus.ConnectSystemBus()
if err != nil {
return nil, fmt.Errorf("connect system bus: %w", err)
}
defer bus.Close()
hasNM, _ := nameHasOwner(bus, "org.freedesktop.NetworkManager")
hasIwd, _ := nameHasOwner(bus, "net.connman.iwd")
hasConn, _ := nameHasOwner(bus, "net.connman")
hasWpa, _ := nameHasOwner(bus, "fi.w1.wpa_supplicant1")
hasNetworkd, _ := nameHasOwner(bus, "org.freedesktop.network1")
res := &DetectResult{
HasNM: hasNM,
HasIwd: hasIwd,
HasConnMan: hasConn,
HasWpaSupp: hasWpa,
HasNetworkd: hasNetworkd,
}
switch {
case hasNM:
res.Backend = BackendNetworkManager
if hasIwd {
res.ChosenReason = "NetworkManager present; iwd also running (likely NM's Wi-Fi backend). Using NM API."
} else {
res.ChosenReason = "NetworkManager present. Using NM API."
}
case hasConn && hasIwd:
res.Backend = BackendConnMan
res.ChosenReason = "ConnMan + iwd detected. Use ConnMan API (iwd is its Wi-Fi daemon)."
case hasIwd && hasNetworkd:
res.Backend = BackendNetworkd
res.ChosenReason = "iwd + systemd-networkd detected. Using iwd for Wi-Fi association and networkd for IP/DHCP."
case hasIwd:
res.Backend = BackendIwd
res.ChosenReason = "iwd detected without NM/ConnMan. Using iwd API."
case hasNetworkd:
res.Backend = BackendNetworkd
res.ChosenReason = "systemd-networkd detected (no NM/ConnMan). Using networkd for L3 and wired."
default:
res.Backend = BackendNone
if hasWpa {
res.ChosenReason = "No NM/ConnMan/iwd; wpa_supplicant present. Consider a wpa_supplicant path."
} else {
res.ChosenReason = "No known network manager bus names found."
}
}
return res, nil
}

View File

@@ -0,0 +1,34 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBackendType_Constants(t *testing.T) {
assert.Equal(t, BackendType(0), BackendNone)
assert.Equal(t, BackendType(1), BackendNetworkManager)
assert.Equal(t, BackendType(2), BackendIwd)
assert.Equal(t, BackendType(3), BackendConnMan)
assert.Equal(t, BackendType(4), BackendNetworkd)
}
func TestDetectResult_HasNetworkdField(t *testing.T) {
result := &DetectResult{
Backend: BackendNetworkd,
HasNetworkd: true,
HasIwd: true,
}
assert.True(t, result.HasNetworkd)
assert.True(t, result.HasIwd)
assert.Equal(t, BackendNetworkd, result.Backend)
}
func TestDetectNetworkStack_Integration(t *testing.T) {
result, err := DetectNetworkStack()
assert.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ChosenReason)
}

View File

@@ -0,0 +1,487 @@
package network
import (
"encoding/json"
"fmt"
"net"
"github.com/AvengeMedia/danklinux/internal/log"
"github.com/AvengeMedia/danklinux/internal/server/models"
)
type Request struct {
ID int `json:"id,omitempty"`
Method string `json:"method"`
Params map[string]interface{} `json:"params,omitempty"`
}
type SuccessResult struct {
Success bool `json:"success"`
Message string `json:"message"`
}
func HandleRequest(conn net.Conn, req Request, manager *Manager) {
switch req.Method {
case "network.getState":
handleGetState(conn, req, manager)
case "network.wifi.scan":
handleScanWiFi(conn, req, manager)
case "network.wifi.networks":
handleGetWiFiNetworks(conn, req, manager)
case "network.wifi.connect":
handleConnectWiFi(conn, req, manager)
case "network.wifi.disconnect":
handleDisconnectWiFi(conn, req, manager)
case "network.wifi.forget":
handleForgetWiFi(conn, req, manager)
case "network.wifi.toggle":
handleToggleWiFi(conn, req, manager)
case "network.wifi.enable":
handleEnableWiFi(conn, req, manager)
case "network.wifi.disable":
handleDisableWiFi(conn, req, manager)
case "network.ethernet.connect.config":
handleConnectEthernetSpecificConfig(conn, req, manager)
case "network.ethernet.connect":
handleConnectEthernet(conn, req, manager)
case "network.ethernet.disconnect":
handleDisconnectEthernet(conn, req, manager)
case "network.preference.set":
handleSetPreference(conn, req, manager)
case "network.info":
handleGetNetworkInfo(conn, req, manager)
case "network.ethernet.info":
handleGetWiredNetworkInfo(conn, req, manager)
case "network.subscribe":
handleSubscribe(conn, req, manager)
case "network.credentials.submit":
handleCredentialsSubmit(conn, req, manager)
case "network.credentials.cancel":
handleCredentialsCancel(conn, req, manager)
case "network.vpn.profiles":
handleListVPNProfiles(conn, req, manager)
case "network.vpn.active":
handleListActiveVPN(conn, req, manager)
case "network.vpn.connect":
handleConnectVPN(conn, req, manager)
case "network.vpn.disconnect":
handleDisconnectVPN(conn, req, manager)
case "network.vpn.disconnectAll":
handleDisconnectAllVPN(conn, req, manager)
case "network.vpn.clearCredentials":
handleClearVPNCredentials(conn, req, manager)
case "network.wifi.setAutoconnect":
handleSetWiFiAutoconnect(conn, req, manager)
default:
models.RespondError(conn, req.ID, fmt.Sprintf("unknown method: %s", req.Method))
}
}
func handleCredentialsSubmit(conn net.Conn, req Request, manager *Manager) {
token, ok := req.Params["token"].(string)
if !ok {
log.Warnf("handleCredentialsSubmit: missing or invalid token parameter")
models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return
}
secretsRaw, ok := req.Params["secrets"].(map[string]interface{})
if !ok {
log.Warnf("handleCredentialsSubmit: missing or invalid secrets parameter")
models.RespondError(conn, req.ID, "missing or invalid 'secrets' parameter")
return
}
secrets := make(map[string]string)
for k, v := range secretsRaw {
if str, ok := v.(string); ok {
secrets[k] = str
}
}
save := true
if saveParam, ok := req.Params["save"].(bool); ok {
save = saveParam
}
if err := manager.SubmitCredentials(token, secrets, save); err != nil {
log.Warnf("handleCredentialsSubmit: failed to submit credentials: %v", err)
models.RespondError(conn, req.ID, err.Error())
return
}
log.Infof("handleCredentialsSubmit: credentials submitted successfully")
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "credentials submitted"})
}
func handleCredentialsCancel(conn net.Conn, req Request, manager *Manager) {
token, ok := req.Params["token"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'token' parameter")
return
}
if err := manager.CancelCredentials(token); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "credentials cancelled"})
}
func handleGetState(conn net.Conn, req Request, manager *Manager) {
state := manager.GetState()
models.Respond(conn, req.ID, state)
}
func handleScanWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.ScanWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "scanning"})
}
func handleGetWiFiNetworks(conn net.Conn, req Request, manager *Manager) {
networks := manager.GetWiFiNetworks()
models.Respond(conn, req.ID, networks)
}
func handleConnectWiFi(conn net.Conn, req Request, manager *Manager) {
ssid, ok := req.Params["ssid"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return
}
var connReq ConnectionRequest
connReq.SSID = ssid
if password, ok := req.Params["password"].(string); ok {
connReq.Password = password
}
if username, ok := req.Params["username"].(string); ok {
connReq.Username = username
}
if interactive, ok := req.Params["interactive"].(bool); ok {
connReq.Interactive = interactive
} else {
state := manager.GetState()
alreadyConnected := state.WiFiConnected && state.WiFiSSID == ssid
if alreadyConnected {
connReq.Interactive = false
} else {
networkInfo, err := manager.GetNetworkInfo(ssid)
isSaved := err == nil && networkInfo.Saved
if isSaved {
connReq.Interactive = false
} else if err == nil && networkInfo.Secured && connReq.Password == "" && connReq.Username == "" {
connReq.Interactive = true
}
}
}
if anonymousIdentity, ok := req.Params["anonymousIdentity"].(string); ok {
connReq.AnonymousIdentity = anonymousIdentity
}
if domainSuffixMatch, ok := req.Params["domainSuffixMatch"].(string); ok {
connReq.DomainSuffixMatch = domainSuffixMatch
}
if err := manager.ConnectWiFi(connReq); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
}
func handleDisconnectWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.DisconnectWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "disconnected"})
}
func handleForgetWiFi(conn net.Conn, req Request, manager *Manager) {
ssid, ok := req.Params["ssid"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return
}
if err := manager.ForgetWiFiNetwork(ssid); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "forgotten"})
}
func handleToggleWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.ToggleWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
state := manager.GetState()
models.Respond(conn, req.ID, map[string]bool{"enabled": state.WiFiEnabled})
}
func handleEnableWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.EnableWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, map[string]bool{"enabled": true})
}
func handleDisableWiFi(conn net.Conn, req Request, manager *Manager) {
if err := manager.DisableWiFi(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, map[string]bool{"enabled": false})
}
func handleConnectEthernetSpecificConfig(conn net.Conn, req Request, manager *Manager) {
uuid, ok := req.Params["uuid"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'uuid' parameter")
return
}
if err := manager.activateConnection(uuid); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
}
func handleConnectEthernet(conn net.Conn, req Request, manager *Manager) {
if err := manager.ConnectEthernet(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "connecting"})
}
func handleDisconnectEthernet(conn net.Conn, req Request, manager *Manager) {
if err := manager.DisconnectEthernet(); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "disconnected"})
}
func handleSetPreference(conn net.Conn, req Request, manager *Manager) {
preference, ok := req.Params["preference"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'preference' parameter")
return
}
if err := manager.SetConnectionPreference(ConnectionPreference(preference)); err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, map[string]string{"preference": preference})
}
func handleGetNetworkInfo(conn net.Conn, req Request, manager *Manager) {
ssid, ok := req.Params["ssid"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return
}
network, err := manager.GetNetworkInfoDetailed(ssid)
if err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, network)
}
func handleGetWiredNetworkInfo(conn net.Conn, req Request, manager *Manager) {
uuid, ok := req.Params["uuid"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'uuid' parameter")
return
}
network, err := manager.GetWiredNetworkInfoDetailed(uuid)
if err != nil {
models.RespondError(conn, req.ID, err.Error())
return
}
models.Respond(conn, req.ID, network)
}
func handleSubscribe(conn net.Conn, req Request, manager *Manager) {
clientID := fmt.Sprintf("client-%p", conn)
stateChan := manager.Subscribe(clientID)
defer manager.Unsubscribe(clientID)
initialState := manager.GetState()
event := NetworkEvent{
Type: EventStateChanged,
Data: initialState,
}
if err := json.NewEncoder(conn).Encode(models.Response[NetworkEvent]{
ID: req.ID,
Result: &event,
}); err != nil {
return
}
for state := range stateChan {
event := NetworkEvent{
Type: EventStateChanged,
Data: state,
}
if err := json.NewEncoder(conn).Encode(models.Response[NetworkEvent]{
Result: &event,
}); err != nil {
return
}
}
}
func handleListVPNProfiles(conn net.Conn, req Request, manager *Manager) {
profiles, err := manager.ListVPNProfiles()
if err != nil {
log.Warnf("handleListVPNProfiles: failed to list profiles: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to list VPN profiles: %v", err))
return
}
models.Respond(conn, req.ID, profiles)
}
func handleListActiveVPN(conn net.Conn, req Request, manager *Manager) {
active, err := manager.ListActiveVPN()
if err != nil {
log.Warnf("handleListActiveVPN: failed to list active VPNs: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to list active VPNs: %v", err))
return
}
models.Respond(conn, req.ID, active)
}
func handleConnectVPN(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := req.Params["uuidOrName"].(string)
if !ok {
name, nameOk := req.Params["name"].(string)
uuid, uuidOk := req.Params["uuid"].(string)
if nameOk {
uuidOrName = name
} else if uuidOk {
uuidOrName = uuid
} else {
log.Warnf("handleConnectVPN: missing uuidOrName/name/uuid parameter")
models.RespondError(conn, req.ID, "missing 'uuidOrName', 'name', or 'uuid' parameter")
return
}
}
// Default to true - only allow one VPN connection at a time
singleActive := true
if sa, ok := req.Params["singleActive"].(bool); ok {
singleActive = sa
}
if err := manager.ConnectVPN(uuidOrName, singleActive); err != nil {
log.Warnf("handleConnectVPN: failed to connect: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to connect VPN: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN connection initiated"})
}
func handleDisconnectVPN(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := req.Params["uuidOrName"].(string)
if !ok {
name, nameOk := req.Params["name"].(string)
uuid, uuidOk := req.Params["uuid"].(string)
if nameOk {
uuidOrName = name
} else if uuidOk {
uuidOrName = uuid
} else {
log.Warnf("handleDisconnectVPN: missing uuidOrName/name/uuid parameter")
models.RespondError(conn, req.ID, "missing 'uuidOrName', 'name', or 'uuid' parameter")
return
}
}
if err := manager.DisconnectVPN(uuidOrName); err != nil {
log.Warnf("handleDisconnectVPN: failed to disconnect: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to disconnect VPN: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN disconnected"})
}
func handleDisconnectAllVPN(conn net.Conn, req Request, manager *Manager) {
if err := manager.DisconnectAllVPN(); err != nil {
log.Warnf("handleDisconnectAllVPN: failed: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to disconnect all VPNs: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "All VPNs disconnected"})
}
func handleClearVPNCredentials(conn net.Conn, req Request, manager *Manager) {
uuidOrName, ok := req.Params["uuid"].(string)
if !ok {
uuidOrName, ok = req.Params["name"].(string)
}
if !ok {
uuidOrName, ok = req.Params["uuidOrName"].(string)
}
if !ok {
log.Warnf("handleClearVPNCredentials: missing uuidOrName/name/uuid parameter")
models.RespondError(conn, req.ID, "missing uuidOrName/name/uuid parameter")
return
}
if err := manager.ClearVPNCredentials(uuidOrName); err != nil {
log.Warnf("handleClearVPNCredentials: failed: %v", err)
models.RespondError(conn, req.ID, fmt.Sprintf("failed to clear VPN credentials: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "VPN credentials cleared"})
}
func handleSetWiFiAutoconnect(conn net.Conn, req Request, manager *Manager) {
ssid, ok := req.Params["ssid"].(string)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'ssid' parameter")
return
}
autoconnect, ok := req.Params["autoconnect"].(bool)
if !ok {
models.RespondError(conn, req.ID, "missing or invalid 'autoconnect' parameter")
return
}
if err := manager.SetWiFiAutoconnect(ssid, autoconnect); err != nil {
models.RespondError(conn, req.ID, fmt.Sprintf("failed to set autoconnect: %v", err))
return
}
models.Respond(conn, req.ID, SuccessResult{Success: true, Message: "autoconnect updated"})
}

View File

@@ -0,0 +1,263 @@
package network
import (
"bytes"
"encoding/json"
"net"
"testing"
"github.com/AvengeMedia/danklinux/internal/server/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNetConn struct {
net.Conn
readBuf *bytes.Buffer
writeBuf *bytes.Buffer
closed bool
}
func newMockNetConn() *mockNetConn {
return &mockNetConn{
readBuf: &bytes.Buffer{},
writeBuf: &bytes.Buffer{},
}
}
func (m *mockNetConn) Read(b []byte) (n int, err error) {
return m.readBuf.Read(b)
}
func (m *mockNetConn) Write(b []byte) (n int, err error) {
return m.writeBuf.Write(b)
}
func (m *mockNetConn) Close() error {
m.closed = true
return nil
}
func TestRespondError_Network(t *testing.T) {
conn := newMockNetConn()
models.RespondError(conn, 123, "test error")
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Equal(t, "test error", resp.Error)
assert.Nil(t, resp.Result)
}
func TestRespond_Network(t *testing.T) {
conn := newMockNetConn()
result := SuccessResult{Success: true, Message: "test"}
models.Respond(conn, 123, result)
var resp models.Response[SuccessResult]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.True(t, resp.Result.Success)
assert.Equal(t, "test", resp.Result.Message)
}
func TestHandleGetState(t *testing.T) {
manager := &Manager{
state: &NetworkState{
NetworkStatus: StatusWiFi,
WiFiSSID: "TestNetwork",
WiFiConnected: true,
},
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "network.getState"}
handleGetState(conn, req, manager)
var resp models.Response[NetworkState]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.Equal(t, StatusWiFi, resp.Result.NetworkStatus)
assert.Equal(t, "TestNetwork", resp.Result.WiFiSSID)
}
func TestHandleGetWiFiNetworks(t *testing.T) {
manager := &Manager{
state: &NetworkState{
WiFiNetworks: []WiFiNetwork{
{SSID: "Network1", Signal: 90},
{SSID: "Network2", Signal: 80},
},
},
}
conn := newMockNetConn()
req := Request{ID: 123, Method: "network.wifi.networks"}
handleGetWiFiNetworks(conn, req, manager)
var resp models.Response[[]WiFiNetwork]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
require.NotNil(t, resp.Result)
assert.Len(t, *resp.Result, 2)
assert.Equal(t, "Network1", (*resp.Result)[0].SSID)
}
func TestHandleConnectWiFi(t *testing.T) {
t.Run("missing ssid parameter", func(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "network.wifi.connect",
Params: map[string]interface{}{},
}
handleConnectWiFi(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'ssid' parameter")
})
}
func TestHandleSetPreference(t *testing.T) {
t.Run("missing preference parameter", func(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "network.preference.set",
Params: map[string]interface{}{},
}
handleSetPreference(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'preference' parameter")
})
}
func TestHandleGetNetworkInfo(t *testing.T) {
t.Run("missing ssid parameter", func(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
}
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "network.info",
Params: map[string]interface{}{},
}
handleGetNetworkInfo(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "missing or invalid 'ssid' parameter")
})
}
func TestHandleRequest(t *testing.T) {
manager := &Manager{
state: &NetworkState{
NetworkStatus: StatusWiFi,
},
}
t.Run("unknown method", func(t *testing.T) {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "network.unknown",
}
HandleRequest(conn, req, manager)
var resp models.Response[any]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Contains(t, resp.Error, "unknown method")
})
t.Run("valid method - getState", func(t *testing.T) {
conn := newMockNetConn()
req := Request{
ID: 123,
Method: "network.getState",
}
HandleRequest(conn, req, manager)
var resp models.Response[NetworkState]
err := json.NewDecoder(conn.writeBuf).Decode(&resp)
require.NoError(t, err)
assert.Equal(t, 123, resp.ID)
assert.Empty(t, resp.Error)
})
}
func TestHandleSubscribe(t *testing.T) {
// This test is complex due to the streaming nature of subscriptions
// Better suited as an integration test
t.Skip("Subscription test requires connection lifecycle management - integration test needed")
}
func TestManager_Subscribe_Unsubscribe(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
subscribers: make(map[string]chan NetworkState),
}
t.Run("subscribe creates channel", func(t *testing.T) {
ch := manager.Subscribe("client1")
assert.NotNil(t, ch)
assert.Len(t, manager.subscribers, 1)
})
t.Run("unsubscribe removes channel", func(t *testing.T) {
manager.Unsubscribe("client1")
assert.Len(t, manager.subscribers, 0)
})
t.Run("unsubscribe non-existent client is safe", func(t *testing.T) {
assert.NotPanics(t, func() {
manager.Unsubscribe("non-existent")
})
})
}

View File

@@ -0,0 +1,53 @@
package network
import "sort"
func frequencyToChannel(freq uint32) uint32 {
if freq >= 2412 && freq <= 2484 {
if freq == 2484 {
return 14
}
return (freq-2412)/5 + 1
}
if freq >= 5170 && freq <= 5825 {
return (freq-5170)/5 + 34
}
if freq >= 5955 && freq <= 7115 {
return (freq-5955)/5 + 1
}
return 0
}
func sortWiFiNetworks(networks []WiFiNetwork) {
sort.Slice(networks, func(i, j int) bool {
if networks[i].Connected && !networks[j].Connected {
return true
}
if !networks[i].Connected && networks[j].Connected {
return false
}
if networks[i].Saved && !networks[j].Saved {
return true
}
if !networks[i].Saved && networks[j].Saved {
return false
}
if !networks[i].Secured && networks[j].Secured {
if networks[i].Signal >= 50 {
return true
}
}
if networks[i].Secured && !networks[j].Secured {
if networks[j].Signal >= 50 {
return false
}
}
return networks[i].Signal > networks[j].Signal
})
}

View File

@@ -0,0 +1,530 @@
package network
import (
"fmt"
"sync"
"time"
"github.com/AvengeMedia/danklinux/internal/log"
)
func NewManager() (*Manager, error) {
detection, err := DetectNetworkStack()
if err != nil {
return nil, fmt.Errorf("failed to detect network stack: %w", err)
}
log.Infof("Network backend detection: %s", detection.ChosenReason)
var backend Backend
switch detection.Backend {
case BackendNetworkManager:
nm, err := NewNetworkManagerBackend()
if err != nil {
return nil, fmt.Errorf("failed to create NetworkManager backend: %w", err)
}
backend = nm
case BackendIwd:
iwd, err := NewIWDBackend()
if err != nil {
return nil, fmt.Errorf("failed to create iwd backend: %w", err)
}
backend = iwd
case BackendNetworkd:
if detection.HasIwd && !detection.HasNM {
wifi, err := NewIWDBackend()
if err != nil {
return nil, fmt.Errorf("failed to create iwd backend: %w", err)
}
l3, err := NewSystemdNetworkdBackend()
if err != nil {
return nil, fmt.Errorf("failed to create networkd backend: %w", err)
}
hybrid, err := NewHybridIwdNetworkdBackend(wifi, l3)
if err != nil {
return nil, fmt.Errorf("failed to create hybrid backend: %w", err)
}
backend = hybrid
} else {
nd, err := NewSystemdNetworkdBackend()
if err != nil {
return nil, fmt.Errorf("failed to create networkd backend: %w", err)
}
backend = nd
}
default:
return nil, fmt.Errorf("no supported network backend found: %s", detection.ChosenReason)
}
m := &Manager{
backend: backend,
state: &NetworkState{
NetworkStatus: StatusDisconnected,
Preference: PreferenceAuto,
WiFiNetworks: []WiFiNetwork{},
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan NetworkState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
credentialSubscribers: make(map[string]chan CredentialPrompt),
credSubMutex: sync.RWMutex{},
}
broker := NewSubscriptionBroker(m.broadcastCredentialPrompt)
if err := backend.SetPromptBroker(broker); err != nil {
return nil, fmt.Errorf("failed to set prompt broker: %w", err)
}
if err := backend.Initialize(); err != nil {
return nil, fmt.Errorf("failed to initialize backend: %w", err)
}
if err := m.syncStateFromBackend(); err != nil {
return nil, fmt.Errorf("failed to sync initial state: %w", err)
}
m.notifierWg.Add(1)
go m.notifier()
if err := backend.StartMonitoring(m.onBackendStateChange); err != nil {
m.Close()
return nil, fmt.Errorf("failed to start monitoring: %w", err)
}
return m, nil
}
func (m *Manager) syncStateFromBackend() error {
backendState, err := m.backend.GetCurrentState()
if err != nil {
return err
}
m.stateMutex.Lock()
m.state.Backend = backendState.Backend
m.state.NetworkStatus = backendState.NetworkStatus
m.state.EthernetIP = backendState.EthernetIP
m.state.EthernetDevice = backendState.EthernetDevice
m.state.EthernetConnected = backendState.EthernetConnected
m.state.EthernetConnectionUuid = backendState.EthernetConnectionUuid
m.state.WiFiIP = backendState.WiFiIP
m.state.WiFiDevice = backendState.WiFiDevice
m.state.WiFiConnected = backendState.WiFiConnected
m.state.WiFiEnabled = backendState.WiFiEnabled
m.state.WiFiSSID = backendState.WiFiSSID
m.state.WiFiBSSID = backendState.WiFiBSSID
m.state.WiFiSignal = backendState.WiFiSignal
m.state.WiFiNetworks = backendState.WiFiNetworks
m.state.WiredConnections = backendState.WiredConnections
m.state.VPNProfiles = backendState.VPNProfiles
m.state.VPNActive = backendState.VPNActive
m.state.IsConnecting = backendState.IsConnecting
m.state.ConnectingSSID = backendState.ConnectingSSID
m.state.LastError = backendState.LastError
m.stateMutex.Unlock()
return nil
}
func (m *Manager) onBackendStateChange() {
if err := m.syncStateFromBackend(); err != nil {
log.Errorf("failed to sync state from backend: %v", err)
}
m.notifySubscribers()
}
func signalChangeSignificant(old, new uint8) bool {
if old == 0 || new == 0 {
return true
}
diff := int(new) - int(old)
if diff < 0 {
diff = -diff
}
return diff >= 5
}
func (m *Manager) snapshotState() NetworkState {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
s := *m.state
s.WiFiNetworks = append([]WiFiNetwork(nil), m.state.WiFiNetworks...)
s.WiredConnections = append([]WiredConnection(nil), m.state.WiredConnections...)
s.VPNProfiles = append([]VPNProfile(nil), m.state.VPNProfiles...)
s.VPNActive = append([]VPNActive(nil), m.state.VPNActive...)
return s
}
func stateChangedMeaningfully(old, new *NetworkState) bool {
if old.NetworkStatus != new.NetworkStatus {
return true
}
if old.Preference != new.Preference {
return true
}
if old.EthernetConnected != new.EthernetConnected {
return true
}
if old.EthernetIP != new.EthernetIP {
return true
}
if old.WiFiConnected != new.WiFiConnected {
return true
}
if old.WiFiEnabled != new.WiFiEnabled {
return true
}
if old.WiFiSSID != new.WiFiSSID {
return true
}
if old.WiFiBSSID != new.WiFiBSSID {
return true
}
if old.WiFiIP != new.WiFiIP {
return true
}
if !signalChangeSignificant(old.WiFiSignal, new.WiFiSignal) {
if old.WiFiSignal != new.WiFiSignal {
return false
}
} else if old.WiFiSignal != new.WiFiSignal {
return true
}
if old.IsConnecting != new.IsConnecting {
return true
}
if old.ConnectingSSID != new.ConnectingSSID {
return true
}
if old.LastError != new.LastError {
return true
}
if len(old.WiFiNetworks) != len(new.WiFiNetworks) {
return true
}
if len(old.WiredConnections) != len(new.WiredConnections) {
return true
}
for i := range old.WiFiNetworks {
oldNet := &old.WiFiNetworks[i]
newNet := &new.WiFiNetworks[i]
if oldNet.SSID != newNet.SSID {
return true
}
if oldNet.Connected != newNet.Connected {
return true
}
if oldNet.Saved != newNet.Saved {
return true
}
if oldNet.Autoconnect != newNet.Autoconnect {
return true
}
}
for i := range old.WiredConnections {
oldNet := &old.WiredConnections[i]
newNet := &new.WiredConnections[i]
if oldNet.ID != newNet.ID {
return true
}
if oldNet.IsActive != newNet.IsActive {
return true
}
}
// Check VPN profiles count
if len(old.VPNProfiles) != len(new.VPNProfiles) {
return true
}
// Check active VPN connections count or state
if len(old.VPNActive) != len(new.VPNActive) {
return true
}
// Check if any active VPN changed
for i := range old.VPNActive {
oldVPN := &old.VPNActive[i]
newVPN := &new.VPNActive[i]
if oldVPN.UUID != newVPN.UUID {
return true
}
if oldVPN.State != newVPN.State {
return true
}
}
return false
}
func (m *Manager) GetState() NetworkState {
return m.snapshotState()
}
func (m *Manager) Subscribe(id string) chan NetworkState {
ch := make(chan NetworkState, 64)
m.subMutex.Lock()
m.subscribers[id] = ch
m.subMutex.Unlock()
return ch
}
func (m *Manager) Unsubscribe(id string) {
m.subMutex.Lock()
if ch, ok := m.subscribers[id]; ok {
close(ch)
delete(m.subscribers, id)
}
m.subMutex.Unlock()
}
func (m *Manager) SubscribeCredentials(id string) chan CredentialPrompt {
ch := make(chan CredentialPrompt, 16)
m.credSubMutex.Lock()
m.credentialSubscribers[id] = ch
m.credSubMutex.Unlock()
return ch
}
func (m *Manager) UnsubscribeCredentials(id string) {
m.credSubMutex.Lock()
if ch, ok := m.credentialSubscribers[id]; ok {
close(ch)
delete(m.credentialSubscribers, id)
}
m.credSubMutex.Unlock()
}
func (m *Manager) broadcastCredentialPrompt(prompt CredentialPrompt) {
m.credSubMutex.RLock()
defer m.credSubMutex.RUnlock()
for _, ch := range m.credentialSubscribers {
select {
case ch <- prompt:
default:
}
}
}
func (m *Manager) notifier() {
defer m.notifierWg.Done()
const minGap = 100 * time.Millisecond
timer := time.NewTimer(minGap)
timer.Stop()
var pending bool
for {
select {
case <-m.stopChan:
timer.Stop()
return
case <-m.dirty:
if pending {
continue
}
pending = true
timer.Reset(minGap)
case <-timer.C:
if !pending {
continue
}
m.subMutex.RLock()
if len(m.subscribers) == 0 {
m.subMutex.RUnlock()
pending = false
continue
}
currentState := m.snapshotState()
if m.lastNotifiedState != nil && !stateChangedMeaningfully(m.lastNotifiedState, &currentState) {
m.subMutex.RUnlock()
pending = false
continue
}
for _, ch := range m.subscribers {
select {
case ch <- currentState:
default:
}
}
m.subMutex.RUnlock()
stateCopy := currentState
m.lastNotifiedState = &stateCopy
pending = false
}
}
}
func (m *Manager) notifySubscribers() {
select {
case m.dirty <- struct{}{}:
default:
}
}
func (m *Manager) SetPromptBroker(broker PromptBroker) error {
return m.backend.SetPromptBroker(broker)
}
func (m *Manager) SubmitCredentials(token string, secrets map[string]string, save bool) error {
return m.backend.SubmitCredentials(token, secrets, save)
}
func (m *Manager) CancelCredentials(token string) error {
return m.backend.CancelCredentials(token)
}
func (m *Manager) GetPromptBroker() PromptBroker {
return m.backend.GetPromptBroker()
}
func (m *Manager) Close() {
close(m.stopChan)
m.notifierWg.Wait()
if m.backend != nil {
m.backend.Close()
}
m.subMutex.Lock()
for _, ch := range m.subscribers {
close(ch)
}
m.subscribers = make(map[string]chan NetworkState)
m.subMutex.Unlock()
}
func (m *Manager) ScanWiFi() error {
return m.backend.ScanWiFi()
}
func (m *Manager) GetWiFiNetworks() []WiFiNetwork {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
networks := make([]WiFiNetwork, len(m.state.WiFiNetworks))
copy(networks, m.state.WiFiNetworks)
return networks
}
func (m *Manager) GetNetworkInfo(ssid string) (*WiFiNetwork, error) {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
for _, network := range m.state.WiFiNetworks {
if network.SSID == ssid {
return &network, nil
}
}
return nil, fmt.Errorf("network not found: %s", ssid)
}
func (m *Manager) GetNetworkInfoDetailed(ssid string) (*NetworkInfoResponse, error) {
return m.backend.GetWiFiNetworkDetails(ssid)
}
func (m *Manager) ToggleWiFi() error {
enabled, err := m.backend.GetWiFiEnabled()
if err != nil {
return fmt.Errorf("failed to get WiFi state: %w", err)
}
err = m.backend.SetWiFiEnabled(!enabled)
if err != nil {
return fmt.Errorf("failed to toggle WiFi: %w", err)
}
return nil
}
func (m *Manager) EnableWiFi() error {
err := m.backend.SetWiFiEnabled(true)
if err != nil {
return fmt.Errorf("failed to enable WiFi: %w", err)
}
return nil
}
func (m *Manager) DisableWiFi() error {
err := m.backend.SetWiFiEnabled(false)
if err != nil {
return fmt.Errorf("failed to disable WiFi: %w", err)
}
return nil
}
func (m *Manager) ConnectWiFi(req ConnectionRequest) error {
return m.backend.ConnectWiFi(req)
}
func (m *Manager) DisconnectWiFi() error {
return m.backend.DisconnectWiFi()
}
func (m *Manager) ForgetWiFiNetwork(ssid string) error {
return m.backend.ForgetWiFiNetwork(ssid)
}
func (m *Manager) GetWiredConfigs() []WiredConnection {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
configs := make([]WiredConnection, len(m.state.WiredConnections))
copy(configs, m.state.WiredConnections)
return configs
}
func (m *Manager) GetWiredNetworkInfoDetailed(uuid string) (*WiredNetworkInfoResponse, error) {
return m.backend.GetWiredNetworkDetails(uuid)
}
func (m *Manager) ConnectEthernet() error {
return m.backend.ConnectEthernet()
}
func (m *Manager) DisconnectEthernet() error {
return m.backend.DisconnectEthernet()
}
func (m *Manager) activateConnection(uuid string) error {
return m.backend.ActivateWiredConnection(uuid)
}
func (m *Manager) ListVPNProfiles() ([]VPNProfile, error) {
return m.backend.ListVPNProfiles()
}
func (m *Manager) ListActiveVPN() ([]VPNActive, error) {
return m.backend.ListActiveVPN()
}
func (m *Manager) ConnectVPN(uuidOrName string, singleActive bool) error {
return m.backend.ConnectVPN(uuidOrName, singleActive)
}
func (m *Manager) DisconnectVPN(uuidOrName string) error {
return m.backend.DisconnectVPN(uuidOrName)
}
func (m *Manager) DisconnectAllVPN() error {
return m.backend.DisconnectAllVPN()
}
func (m *Manager) ClearVPNCredentials(uuidOrName string) error {
return m.backend.ClearVPNCredentials(uuidOrName)
}
func (m *Manager) SetWiFiAutoconnect(ssid string, autoconnect bool) error {
return m.backend.SetWiFiAutoconnect(ssid, autoconnect)
}

View File

@@ -0,0 +1,209 @@
package network
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestManager_GetState(t *testing.T) {
state := &NetworkState{
NetworkStatus: StatusWiFi,
WiFiSSID: "TestNetwork",
WiFiConnected: true,
}
manager := &Manager{
state: state,
stateMutex: sync.RWMutex{},
}
result := manager.GetState()
assert.Equal(t, StatusWiFi, result.NetworkStatus)
assert.Equal(t, "TestNetwork", result.WiFiSSID)
assert.True(t, result.WiFiConnected)
}
func TestManager_NotifySubscribers(t *testing.T) {
manager := &Manager{
state: &NetworkState{
NetworkStatus: StatusWiFi,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan NetworkState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
}
manager.notifierWg.Add(1)
go manager.notifier()
ch := make(chan NetworkState, 10)
manager.subMutex.Lock()
manager.subscribers["test-client"] = ch
manager.subMutex.Unlock()
manager.notifySubscribers()
select {
case state := <-ch:
assert.Equal(t, StatusWiFi, state.NetworkStatus)
case <-time.After(200 * time.Millisecond):
t.Fatal("did not receive state update")
}
close(manager.stopChan)
manager.notifierWg.Wait()
}
func TestManager_NotifySubscribers_Debounce(t *testing.T) {
manager := &Manager{
state: &NetworkState{
NetworkStatus: StatusWiFi,
},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan NetworkState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
}
manager.notifierWg.Add(1)
go manager.notifier()
ch := make(chan NetworkState, 10)
manager.subMutex.Lock()
manager.subscribers["test-client"] = ch
manager.subMutex.Unlock()
manager.notifySubscribers()
manager.notifySubscribers()
manager.notifySubscribers()
receivedCount := 0
timeout := time.After(200 * time.Millisecond)
for {
select {
case <-ch:
receivedCount++
case <-timeout:
assert.Equal(t, 1, receivedCount, "should receive exactly one debounced update")
close(manager.stopChan)
manager.notifierWg.Wait()
return
}
}
}
func TestManager_Close(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
stateMutex: sync.RWMutex{},
subscribers: make(map[string]chan NetworkState),
subMutex: sync.RWMutex{},
stopChan: make(chan struct{}),
}
ch1 := make(chan NetworkState, 1)
ch2 := make(chan NetworkState, 1)
manager.subMutex.Lock()
manager.subscribers["client1"] = ch1
manager.subscribers["client2"] = ch2
manager.subMutex.Unlock()
manager.Close()
select {
case <-manager.stopChan:
case <-time.After(100 * time.Millisecond):
t.Fatal("stopChan not closed")
}
_, ok1 := <-ch1
_, ok2 := <-ch2
assert.False(t, ok1, "ch1 should be closed")
assert.False(t, ok2, "ch2 should be closed")
assert.Len(t, manager.subscribers, 0)
}
func TestManager_Subscribe(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
subscribers: make(map[string]chan NetworkState),
subMutex: sync.RWMutex{},
}
ch := manager.Subscribe("test-client")
assert.NotNil(t, ch)
assert.Equal(t, 64, cap(ch))
manager.subMutex.RLock()
_, exists := manager.subscribers["test-client"]
manager.subMutex.RUnlock()
assert.True(t, exists)
}
func TestManager_Unsubscribe(t *testing.T) {
manager := &Manager{
state: &NetworkState{},
subscribers: make(map[string]chan NetworkState),
subMutex: sync.RWMutex{},
}
ch := manager.Subscribe("test-client")
manager.Unsubscribe("test-client")
_, ok := <-ch
assert.False(t, ok)
manager.subMutex.RLock()
_, exists := manager.subscribers["test-client"]
manager.subMutex.RUnlock()
assert.False(t, exists)
}
func TestNewManager(t *testing.T) {
t.Run("attempts to create manager", func(t *testing.T) {
manager, err := NewManager()
if err != nil {
assert.Nil(t, manager)
} else {
assert.NotNil(t, manager)
assert.NotNil(t, manager.state)
assert.NotNil(t, manager.subscribers)
assert.NotNil(t, manager.stopChan)
manager.Close()
}
})
}
func TestManager_GetState_ThreadSafe(t *testing.T) {
manager := &Manager{
state: &NetworkState{
NetworkStatus: StatusWiFi,
WiFiSSID: "TestNetwork",
},
stateMutex: sync.RWMutex{},
}
done := make(chan bool)
for i := 0; i < 10; i++ {
go func() {
state := manager.GetState()
assert.Equal(t, StatusWiFi, state.NetworkStatus)
done <- true
}()
}
for i := 0; i < 10; i++ {
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for goroutines")
}
}
}

View File

@@ -0,0 +1 @@
package network

View File

@@ -0,0 +1,138 @@
package network
import (
"fmt"
"time"
"github.com/Wifx/gonetworkmanager/v2"
)
func (m *Manager) SetConnectionPreference(pref ConnectionPreference) error {
switch pref {
case PreferenceWiFi, PreferenceEthernet, PreferenceAuto:
default:
return fmt.Errorf("invalid preference: %s", pref)
}
m.stateMutex.Lock()
m.state.Preference = pref
m.stateMutex.Unlock()
if _, ok := m.backend.(*NetworkManagerBackend); !ok {
m.notifySubscribers()
return nil
}
switch pref {
case PreferenceWiFi:
return m.prioritizeWiFi()
case PreferenceEthernet:
return m.prioritizeEthernet()
case PreferenceAuto:
return m.balancePriorities()
}
return nil
}
func (m *Manager) prioritizeWiFi() error {
if err := m.setConnectionMetrics("802-11-wireless", 50); err != nil {
return err
}
if err := m.setConnectionMetrics("802-3-ethernet", 100); err != nil {
return err
}
m.notifySubscribers()
return nil
}
func (m *Manager) prioritizeEthernet() error {
if err := m.setConnectionMetrics("802-3-ethernet", 50); err != nil {
return err
}
if err := m.setConnectionMetrics("802-11-wireless", 100); err != nil {
return err
}
m.notifySubscribers()
return nil
}
func (m *Manager) balancePriorities() error {
if err := m.setConnectionMetrics("802-3-ethernet", 50); err != nil {
return err
}
if err := m.setConnectionMetrics("802-11-wireless", 50); err != nil {
return err
}
m.notifySubscribers()
return nil
}
func (m *Manager) setConnectionMetrics(connType string, metric uint32) error {
settingsMgr, err := gonetworkmanager.NewSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
}
connections, err := settingsMgr.ListConnections()
if err != nil {
return fmt.Errorf("failed to get connections: %w", err)
}
for _, conn := range connections {
connSettings, err := conn.GetSettings()
if err != nil {
continue
}
if connMeta, ok := connSettings["connection"]; ok {
if cType, ok := connMeta["type"].(string); ok && cType == connType {
if connSettings["ipv4"] == nil {
connSettings["ipv4"] = make(map[string]interface{})
}
if ipv4Map := connSettings["ipv4"]; ipv4Map != nil {
ipv4Map["route-metric"] = int64(metric)
}
if connSettings["ipv6"] == nil {
connSettings["ipv6"] = make(map[string]interface{})
}
if ipv6Map := connSettings["ipv6"]; ipv6Map != nil {
ipv6Map["route-metric"] = int64(metric)
}
err = conn.Update(connSettings)
if err != nil {
continue
}
}
}
}
return nil
}
func (m *Manager) GetConnectionPreference() ConnectionPreference {
m.stateMutex.RLock()
defer m.stateMutex.RUnlock()
return m.state.Preference
}
func (m *Manager) WasRecentlyFailed(ssid string) bool {
if nm, ok := m.backend.(*NetworkManagerBackend); ok {
nm.failedMutex.RLock()
defer nm.failedMutex.RUnlock()
if nm.lastFailedSSID == ssid {
elapsed := time.Now().Unix() - nm.lastFailedTime
return elapsed < 10
}
}
return false
}

View File

@@ -0,0 +1,50 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestManager_SetConnectionPreference(t *testing.T) {
t.Run("invalid preference", func(t *testing.T) {
manager := &Manager{
state: &NetworkState{
Preference: PreferenceAuto,
},
}
err := manager.SetConnectionPreference(ConnectionPreference("invalid"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid preference")
})
}
func TestManager_GetConnectionPreference(t *testing.T) {
tests := []struct {
name string
preference ConnectionPreference
}{
{"auto", PreferenceAuto},
{"wifi", PreferenceWiFi},
{"ethernet", PreferenceEthernet},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &Manager{
state: &NetworkState{
Preference: tt.preference,
},
}
result := manager.GetConnectionPreference()
assert.Equal(t, tt.preference, result)
})
}
}
// Note: Full testing of priority operations would require mocking NetworkManager
// D-Bus interfaces. The tests above cover the basic logic and error handling.
// Integration tests would be needed for complete coverage of network connection
// priority updates and reactivation.

View File

@@ -0,0 +1,146 @@
package network
import (
"context"
"fmt"
"sync"
"github.com/AvengeMedia/danklinux/internal/errdefs"
"github.com/AvengeMedia/danklinux/internal/log"
)
type SubscriptionBroker struct {
mu sync.RWMutex
pending map[string]chan PromptReply
requests map[string]PromptRequest
pathSettingToToken map[string]string
broadcastPrompt func(CredentialPrompt)
}
func NewSubscriptionBroker(broadcastPrompt func(CredentialPrompt)) PromptBroker {
return &SubscriptionBroker{
pending: make(map[string]chan PromptReply),
requests: make(map[string]PromptRequest),
pathSettingToToken: make(map[string]string),
broadcastPrompt: broadcastPrompt,
}
}
func (b *SubscriptionBroker) Ask(ctx context.Context, req PromptRequest) (string, error) {
pathSettingKey := fmt.Sprintf("%s:%s", req.ConnectionPath, req.SettingName)
b.mu.Lock()
existingToken, alreadyPending := b.pathSettingToToken[pathSettingKey]
b.mu.Unlock()
if alreadyPending {
log.Infof("[SubscriptionBroker] Duplicate prompt for %s, returning existing token", pathSettingKey)
return existingToken, nil
}
token, err := generateToken()
if err != nil {
return "", err
}
replyChan := make(chan PromptReply, 1)
b.mu.Lock()
b.pending[token] = replyChan
b.requests[token] = req
b.pathSettingToToken[pathSettingKey] = token
b.mu.Unlock()
if b.broadcastPrompt != nil {
prompt := CredentialPrompt{
Token: token,
Name: req.Name,
SSID: req.SSID,
ConnType: req.ConnType,
VpnService: req.VpnService,
Setting: req.SettingName,
Fields: req.Fields,
Hints: req.Hints,
Reason: req.Reason,
ConnectionId: req.ConnectionId,
ConnectionUuid: req.ConnectionUuid,
}
b.broadcastPrompt(prompt)
}
return token, nil
}
func (b *SubscriptionBroker) Wait(ctx context.Context, token string) (PromptReply, error) {
b.mu.RLock()
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists {
return PromptReply{}, fmt.Errorf("unknown token: %s", token)
}
select {
case <-ctx.Done():
b.cleanup(token)
return PromptReply{}, errdefs.ErrSecretPromptTimeout
case reply := <-replyChan:
b.cleanup(token)
if reply.Cancel {
return reply, errdefs.ErrSecretPromptCancelled
}
return reply, nil
}
}
func (b *SubscriptionBroker) Resolve(token string, reply PromptReply) error {
b.mu.RLock()
replyChan, exists := b.pending[token]
b.mu.RUnlock()
if !exists {
log.Warnf("[SubscriptionBroker] Resolve: unknown or expired token: %s", token)
return fmt.Errorf("unknown or expired token: %s", token)
}
select {
case replyChan <- reply:
return nil
default:
log.Warnf("[SubscriptionBroker] Resolve: failed to deliver reply for token %s (channel full or closed)", token)
return fmt.Errorf("failed to deliver reply for token: %s", token)
}
}
func (b *SubscriptionBroker) cleanup(token string) {
b.mu.Lock()
defer b.mu.Unlock()
if req, exists := b.requests[token]; exists {
pathSettingKey := fmt.Sprintf("%s:%s", req.ConnectionPath, req.SettingName)
delete(b.pathSettingToToken, pathSettingKey)
}
delete(b.pending, token)
delete(b.requests, token)
}
func (b *SubscriptionBroker) Cancel(path string, setting string) error {
pathSettingKey := fmt.Sprintf("%s:%s", path, setting)
b.mu.Lock()
token, exists := b.pathSettingToToken[pathSettingKey]
b.mu.Unlock()
if !exists {
log.Infof("[SubscriptionBroker] Cancel: no pending prompt for %s", pathSettingKey)
return nil
}
log.Infof("[SubscriptionBroker] Cancelling prompt for %s (token=%s)", pathSettingKey, token)
reply := PromptReply{
Cancel: true,
}
return b.Resolve(token, reply)
}

View File

@@ -0,0 +1,15 @@
package network
// NewTestManager creates a Manager for testing with a provided backend
func NewTestManager(backend Backend, state *NetworkState) *Manager {
if state == nil {
state = &NetworkState{}
}
return &Manager{
backend: backend,
state: state,
subscribers: make(map[string]chan NetworkState),
stopChan: make(chan struct{}),
dirty: make(chan struct{}, 1),
}
}

View File

@@ -0,0 +1,190 @@
package network
import (
"sync"
"github.com/godbus/dbus/v5"
)
type NetworkStatus string
const (
StatusDisconnected NetworkStatus = "disconnected"
StatusEthernet NetworkStatus = "ethernet"
StatusWiFi NetworkStatus = "wifi"
StatusVPN NetworkStatus = "vpn"
)
type ConnectionPreference string
const (
PreferenceAuto ConnectionPreference = "auto"
PreferenceWiFi ConnectionPreference = "wifi"
PreferenceEthernet ConnectionPreference = "ethernet"
)
type WiFiNetwork struct {
SSID string `json:"ssid"`
BSSID string `json:"bssid"`
Signal uint8 `json:"signal"`
Secured bool `json:"secured"`
Enterprise bool `json:"enterprise"`
Connected bool `json:"connected"`
Saved bool `json:"saved"`
Autoconnect bool `json:"autoconnect"`
Frequency uint32 `json:"frequency"`
Mode string `json:"mode"`
Rate uint32 `json:"rate"`
Channel uint32 `json:"channel"`
}
type VPNProfile struct {
Name string `json:"name"`
UUID string `json:"uuid"`
Type string `json:"type"`
ServiceType string `json:"serviceType"`
}
type VPNActive struct {
Name string `json:"name"`
UUID string `json:"uuid"`
Device string `json:"device,omitempty"`
State string `json:"state,omitempty"`
Type string `json:"type"`
Plugin string `json:"serviceType"`
}
type VPNState struct {
Profiles []VPNProfile `json:"profiles"`
Active []VPNActive `json:"activeConnections"`
}
type NetworkState struct {
Backend string `json:"backend"`
NetworkStatus NetworkStatus `json:"networkStatus"`
Preference ConnectionPreference `json:"preference"`
EthernetIP string `json:"ethernetIP"`
EthernetDevice string `json:"ethernetDevice"`
EthernetConnected bool `json:"ethernetConnected"`
EthernetConnectionUuid string `json:"ethernetConnectionUuid"`
WiFiIP string `json:"wifiIP"`
WiFiDevice string `json:"wifiDevice"`
WiFiConnected bool `json:"wifiConnected"`
WiFiEnabled bool `json:"wifiEnabled"`
WiFiSSID string `json:"wifiSSID"`
WiFiBSSID string `json:"wifiBSSID"`
WiFiSignal uint8 `json:"wifiSignal"`
WiFiNetworks []WiFiNetwork `json:"wifiNetworks"`
WiredConnections []WiredConnection `json:"wiredConnections"`
VPNProfiles []VPNProfile `json:"vpnProfiles"`
VPNActive []VPNActive `json:"vpnActive"`
IsConnecting bool `json:"isConnecting"`
ConnectingSSID string `json:"connectingSSID"`
LastError string `json:"lastError"`
}
type ConnectionRequest struct {
SSID string `json:"ssid"`
Password string `json:"password,omitempty"`
Username string `json:"username,omitempty"`
AnonymousIdentity string `json:"anonymousIdentity,omitempty"`
DomainSuffixMatch string `json:"domainSuffixMatch,omitempty"`
Interactive bool `json:"interactive,omitempty"`
}
type WiredConnection struct {
Path dbus.ObjectPath `json:"path"`
ID string `json:"id"`
UUID string `json:"uuid"`
Type string `json:"type"`
IsActive bool `json:"isActive"`
}
type PriorityUpdate struct {
Preference ConnectionPreference `json:"preference"`
}
type Manager struct {
backend Backend
state *NetworkState
stateMutex sync.RWMutex
subscribers map[string]chan NetworkState
subMutex sync.RWMutex
stopChan chan struct{}
dirty chan struct{}
notifierWg sync.WaitGroup
lastNotifiedState *NetworkState
credentialSubscribers map[string]chan CredentialPrompt
credSubMutex sync.RWMutex
}
type EventType string
const (
EventStateChanged EventType = "state_changed"
EventNetworksUpdated EventType = "networks_updated"
EventConnecting EventType = "connecting"
EventConnected EventType = "connected"
EventDisconnected EventType = "disconnected"
EventError EventType = "error"
)
type NetworkEvent struct {
Type EventType `json:"type"`
Data NetworkState `json:"data"`
}
type PromptRequest struct {
Name string `json:"name"`
SSID string `json:"ssid"`
ConnType string `json:"connType"`
VpnService string `json:"vpnService"`
SettingName string `json:"setting"`
Fields []string `json:"fields"`
Hints []string `json:"hints"`
Reason string `json:"reason"`
ConnectionId string `json:"connectionId"`
ConnectionUuid string `json:"connectionUuid"`
ConnectionPath string `json:"connectionPath"`
}
type PromptReply struct {
Secrets map[string]string `json:"secrets"`
Save bool `json:"save"`
Cancel bool `json:"cancel"`
}
type CredentialPrompt struct {
Token string `json:"token"`
Name string `json:"name"`
SSID string `json:"ssid"`
ConnType string `json:"connType"`
VpnService string `json:"vpnService"`
Setting string `json:"setting"`
Fields []string `json:"fields"`
Hints []string `json:"hints"`
Reason string `json:"reason"`
ConnectionId string `json:"connectionId"`
ConnectionUuid string `json:"connectionUuid"`
}
type NetworkInfoResponse struct {
SSID string `json:"ssid"`
Bands []WiFiNetwork `json:"bands"`
}
type WiredNetworkInfoResponse struct {
UUID string `json:"uuid"`
IFace string `json:"iface"`
Driver string `json:"driver"`
HwAddr string `json:"hwAddr"`
Speed string `json:"speed"`
IPv4 WiredIPConfig `json:"IPv4s"`
IPv6 WiredIPConfig `json:"IPv6s"`
}
type WiredIPConfig struct {
IPs []string `json:"ips"`
Gateway string `json:"gateway"`
DNS string `json:"dns"`
}

View File

@@ -0,0 +1,178 @@
package network
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNetworkStatus_Constants(t *testing.T) {
assert.Equal(t, NetworkStatus("disconnected"), StatusDisconnected)
assert.Equal(t, NetworkStatus("ethernet"), StatusEthernet)
assert.Equal(t, NetworkStatus("wifi"), StatusWiFi)
}
func TestConnectionPreference_Constants(t *testing.T) {
assert.Equal(t, ConnectionPreference("auto"), PreferenceAuto)
assert.Equal(t, ConnectionPreference("wifi"), PreferenceWiFi)
assert.Equal(t, ConnectionPreference("ethernet"), PreferenceEthernet)
}
func TestEventType_Constants(t *testing.T) {
assert.Equal(t, EventType("state_changed"), EventStateChanged)
assert.Equal(t, EventType("networks_updated"), EventNetworksUpdated)
assert.Equal(t, EventType("connecting"), EventConnecting)
assert.Equal(t, EventType("connected"), EventConnected)
assert.Equal(t, EventType("disconnected"), EventDisconnected)
assert.Equal(t, EventType("error"), EventError)
}
func TestWiFiNetwork_JSON(t *testing.T) {
network := WiFiNetwork{
SSID: "TestNetwork",
BSSID: "00:11:22:33:44:55",
Signal: 85,
Secured: true,
Enterprise: false,
Connected: true,
Saved: true,
Frequency: 2437,
Mode: "infrastructure",
Rate: 300,
Channel: 6,
}
data, err := json.Marshal(network)
require.NoError(t, err)
var decoded WiFiNetwork
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, network.SSID, decoded.SSID)
assert.Equal(t, network.BSSID, decoded.BSSID)
assert.Equal(t, network.Signal, decoded.Signal)
assert.Equal(t, network.Secured, decoded.Secured)
assert.Equal(t, network.Enterprise, decoded.Enterprise)
assert.Equal(t, network.Connected, decoded.Connected)
assert.Equal(t, network.Saved, decoded.Saved)
assert.Equal(t, network.Frequency, decoded.Frequency)
assert.Equal(t, network.Mode, decoded.Mode)
assert.Equal(t, network.Rate, decoded.Rate)
assert.Equal(t, network.Channel, decoded.Channel)
}
func TestNetworkState_JSON(t *testing.T) {
state := NetworkState{
NetworkStatus: StatusWiFi,
Preference: PreferenceAuto,
EthernetIP: "192.168.1.100",
EthernetDevice: "eth0",
EthernetConnected: false,
WiFiIP: "192.168.1.101",
WiFiDevice: "wlan0",
WiFiConnected: true,
WiFiEnabled: true,
WiFiSSID: "TestNetwork",
WiFiBSSID: "00:11:22:33:44:55",
WiFiSignal: 85,
WiFiNetworks: []WiFiNetwork{
{SSID: "Network1", Signal: 90},
{SSID: "Network2", Signal: 60},
},
IsConnecting: false,
ConnectingSSID: "",
LastError: "",
}
data, err := json.Marshal(state)
require.NoError(t, err)
var decoded NetworkState
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, state.NetworkStatus, decoded.NetworkStatus)
assert.Equal(t, state.Preference, decoded.Preference)
assert.Equal(t, state.WiFiIP, decoded.WiFiIP)
assert.Equal(t, state.WiFiSSID, decoded.WiFiSSID)
assert.Equal(t, len(state.WiFiNetworks), len(decoded.WiFiNetworks))
}
func TestConnectionRequest_JSON(t *testing.T) {
t.Run("with password", func(t *testing.T) {
req := ConnectionRequest{
SSID: "TestNetwork",
Password: "testpass123",
}
data, err := json.Marshal(req)
require.NoError(t, err)
var decoded ConnectionRequest
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, req.SSID, decoded.SSID)
assert.Equal(t, req.Password, decoded.Password)
assert.Empty(t, decoded.Username)
})
t.Run("with username and password (enterprise)", func(t *testing.T) {
req := ConnectionRequest{
SSID: "EnterpriseNetwork",
Password: "testpass123",
Username: "testuser",
}
data, err := json.Marshal(req)
require.NoError(t, err)
var decoded ConnectionRequest
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, req.SSID, decoded.SSID)
assert.Equal(t, req.Password, decoded.Password)
assert.Equal(t, req.Username, decoded.Username)
})
}
func TestPriorityUpdate_JSON(t *testing.T) {
update := PriorityUpdate{
Preference: PreferenceWiFi,
}
data, err := json.Marshal(update)
require.NoError(t, err)
var decoded PriorityUpdate
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, update.Preference, decoded.Preference)
}
func TestNetworkEvent_JSON(t *testing.T) {
event := NetworkEvent{
Type: EventStateChanged,
Data: NetworkState{
NetworkStatus: StatusWiFi,
WiFiSSID: "TestNetwork",
WiFiConnected: true,
},
}
data, err := json.Marshal(event)
require.NoError(t, err)
var decoded NetworkEvent
err = json.Unmarshal(data, &decoded)
require.NoError(t, err)
assert.Equal(t, event.Type, decoded.Type)
assert.Equal(t, event.Data.NetworkStatus, decoded.Data.NetworkStatus)
assert.Equal(t, event.Data.WiFiSSID, decoded.Data.WiFiSSID)
}

View File

@@ -0,0 +1,148 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFrequencyToChannel(t *testing.T) {
tests := []struct {
name string
frequency uint32
channel uint32
}{
{"2.4 GHz channel 1", 2412, 1},
{"2.4 GHz channel 6", 2437, 6},
{"2.4 GHz channel 11", 2462, 11},
{"2.4 GHz channel 14", 2484, 14},
{"5 GHz channel 36", 5180, 36},
{"5 GHz channel 40", 5200, 40},
{"5 GHz channel 165", 5825, 165},
{"6 GHz channel 1", 5955, 1},
{"6 GHz channel 233", 7115, 233},
{"Unknown frequency", 1000, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := frequencyToChannel(tt.frequency)
assert.Equal(t, tt.channel, result)
})
}
}
func TestSortWiFiNetworks(t *testing.T) {
t.Run("connected network comes first", func(t *testing.T) {
networks := []WiFiNetwork{
{SSID: "Network1", Signal: 90, Connected: false},
{SSID: "Network2", Signal: 80, Connected: true},
{SSID: "Network3", Signal: 70, Connected: false},
}
sortWiFiNetworks(networks)
assert.Equal(t, "Network2", networks[0].SSID)
assert.True(t, networks[0].Connected)
})
t.Run("sorts by signal strength", func(t *testing.T) {
networks := []WiFiNetwork{
{SSID: "Weak", Signal: 40, Secured: true},
{SSID: "Strong", Signal: 90, Secured: true},
{SSID: "Medium", Signal: 60, Secured: true},
}
sortWiFiNetworks(networks)
assert.Equal(t, "Strong", networks[0].SSID)
assert.Equal(t, "Medium", networks[1].SSID)
assert.Equal(t, "Weak", networks[2].SSID)
})
t.Run("prioritizes open networks with good signal", func(t *testing.T) {
networks := []WiFiNetwork{
{SSID: "SecureWeak", Signal: 40, Secured: true},
{SSID: "OpenStrong", Signal: 60, Secured: false},
{SSID: "SecureStrong", Signal: 90, Secured: true},
}
sortWiFiNetworks(networks)
assert.Equal(t, "OpenStrong", networks[0].SSID)
openIdx := -1
weakSecureIdx := -1
for i, n := range networks {
if n.SSID == "OpenStrong" {
openIdx = i
}
if n.SSID == "SecureWeak" {
weakSecureIdx = i
}
}
assert.Less(t, openIdx, weakSecureIdx, "OpenStrong should come before SecureWeak")
})
t.Run("prioritizes saved networks after connected", func(t *testing.T) {
networks := []WiFiNetwork{
{SSID: "UnsavedStrong", Signal: 95, Saved: false},
{SSID: "SavedMedium", Signal: 60, Saved: true},
{SSID: "SavedWeak", Signal: 50, Saved: true},
{SSID: "UnsavedMedium", Signal: 70, Saved: false},
}
sortWiFiNetworks(networks)
assert.Equal(t, "SavedMedium", networks[0].SSID)
assert.Equal(t, "SavedWeak", networks[1].SSID)
assert.Equal(t, "UnsavedStrong", networks[2].SSID)
assert.Equal(t, "UnsavedMedium", networks[3].SSID)
})
}
func TestManager_GetWiFiNetworks(t *testing.T) {
manager := &Manager{
state: &NetworkState{
WiFiNetworks: []WiFiNetwork{
{SSID: "Network1", Signal: 90},
{SSID: "Network2", Signal: 80},
},
},
}
networks := manager.GetWiFiNetworks()
assert.Len(t, networks, 2)
assert.Equal(t, "Network1", networks[0].SSID)
assert.Equal(t, "Network2", networks[1].SSID)
networks[0].SSID = "Modified"
assert.Equal(t, "Network1", manager.state.WiFiNetworks[0].SSID)
}
func TestManager_GetNetworkInfo(t *testing.T) {
manager := &Manager{
state: &NetworkState{
WiFiNetworks: []WiFiNetwork{
{SSID: "Network1", Signal: 90, BSSID: "00:11:22:33:44:55"},
{SSID: "Network2", Signal: 80, BSSID: "AA:BB:CC:DD:EE:FF"},
},
},
}
t.Run("finds existing network", func(t *testing.T) {
network, err := manager.GetNetworkInfo("Network1")
assert.NoError(t, err)
assert.NotNil(t, network)
assert.Equal(t, "Network1", network.SSID)
assert.Equal(t, uint8(90), network.Signal)
})
t.Run("returns error for non-existent network", func(t *testing.T) {
network, err := manager.GetNetworkInfo("NonExistent")
assert.Error(t, err)
assert.Nil(t, network)
assert.Contains(t, err.Error(), "network not found")
})
}

View File

@@ -0,0 +1,23 @@
package network
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestManager_GetWiredConfigs(t *testing.T) {
manager := &Manager{
state: &NetworkState{
EthernetConnected: true,
WiredConnections: []WiredConnection{
{ID: "Test", IsActive: true},
},
},
}
configs := manager.GetWiredConfigs()
assert.Len(t, configs, 1)
assert.Equal(t, "Test", configs[0].ID)
}

Some files were not shown because too many files have changed in this diff Show More