code auth added

This commit is contained in:
partisan 2024-08-08 13:35:50 +02:00
parent f769f70ce7
commit faa20dc064
5 changed files with 127 additions and 145 deletions

18
init.go
View file

@ -11,15 +11,16 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
) )
// Configuration structure
type Config struct { type Config struct {
Port int Port int
AuthCode string AuthCode string
Peers []string Peers []string
PeerID string
OpenSearch OpenSearchConfig OpenSearch OpenSearchConfig
} }
@ -27,7 +28,6 @@ type OpenSearchConfig struct {
Domain string Domain string
} }
// Default configuration values
var defaultConfig = Config{ var defaultConfig = Config{
Port: 5000, Port: 5000,
OpenSearch: OpenSearchConfig{ OpenSearch: OpenSearchConfig{
@ -57,11 +57,21 @@ func main() {
saveConfig(config) saveConfig(config)
} }
// Initialize P2P
var nodeErr error
hostID, nodeErr = initP2P()
if nodeErr != nil {
log.Fatalf("Failed to initialize P2P: %v", nodeErr)
}
config.PeerID = hostID.String()
if len(config.Peers) > 0 { if len(config.Peers) > 0 {
go startNodeClient(config.Peers) time.Sleep(2 * time.Second) // Give some time for connections to establish
startElection() startElection()
} }
go startNodeClient()
runServer() runServer()
} }
@ -103,7 +113,7 @@ func createConfig() error {
fmt.Print("Do you want to connect to other nodes? (yes/no): ") fmt.Print("Do you want to connect to other nodes? (yes/no): ")
connectNodes, _ := reader.ReadString('\n') connectNodes, _ := reader.ReadString('\n')
if strings.TrimSpace(connectNodes) == "yes" { if strings.TrimSpace(connectNodes) == "yes" {
fmt.Println("Enter peer addresses (comma separated, e.g., http://localhost:5000,http://localhost:5001): ") fmt.Println("Enter peer addresses (comma separated, e.g., /ip4/127.0.0.1/tcp/5000,/ip4/127.0.0.1/tcp/5001): ")
peersStr, _ := reader.ReadString('\n') peersStr, _ := reader.ReadString('\n')
if peersStr != "\n" { if peersStr != "\n" {
config.Peers = strings.Split(strings.TrimSpace(peersStr), ",") config.Peers = strings.Split(strings.TrimSpace(peersStr), ",")

30
main.go
View file

@ -122,20 +122,19 @@ func parsePageParameter(pageStr string) int {
} }
func runServer() { func runServer() {
// http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
// http.HandleFunc("/", handleSearch) http.HandleFunc("/", handleSearch)
// http.HandleFunc("/search", handleSearch) http.HandleFunc("/search", handleSearch)
// http.HandleFunc("/img_proxy", handleImageProxy) http.HandleFunc("/img_proxy", handleImageProxy)
// http.HandleFunc("/settings", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/node", handleNodeRequest)
// http.ServeFile(w, r, "templates/settings.html") http.HandleFunc("/settings", func(w http.ResponseWriter, r *http.Request) {
// }) http.ServeFile(w, r, "templates/settings.html")
// http.HandleFunc("/opensearch.xml", func(w http.ResponseWriter, r *http.Request) { })
// w.Header().Set("Content-Type", "application/opensearchdescription+xml") http.HandleFunc("/opensearch.xml", func(w http.ResponseWriter, r *http.Request) {
// http.ServeFile(w, r, "static/opensearch.xml") w.Header().Set("Content-Type", "application/opensearchdescription+xml")
// }) http.ServeFile(w, r, "static/opensearch.xml")
// initializeTorrentSites() })
initializeTorrentSites()
http.HandleFunc("/node", handleNodeRequest) // Handle node requests
config := loadConfig() config := loadConfig()
generateOpenSearchXML(config) generateOpenSearchXML(config)
@ -143,9 +142,6 @@ func runServer() {
fmt.Printf("Server is listening on http://localhost:%d\n", config.Port) fmt.Printf("Server is listening on http://localhost:%d\n", config.Port)
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)) log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil))
// Start node communication client
go startNodeClient(peers)
// Start automatic update checker // Start automatic update checker
go checkForUpdates() go checkForUpdates()
} }

View file

@ -24,7 +24,12 @@ func sendHeartbeats() {
return return
} }
for _, node := range peers { for _, node := range peers {
err := sendMessage(node, authCode, "heartbeat", authCode) msg := Message{
ID: hostID.Pretty(),
Type: "heartbeat",
Content: authCode,
}
err := sendMessage(node, msg)
if err != nil { if err != nil {
log.Printf("Error sending heartbeat to %s: %v", node, err) log.Printf("Error sending heartbeat to %s: %v", node, err)
} }
@ -55,7 +60,12 @@ func startElection() {
defer masterNodeMux.Unlock() defer masterNodeMux.Unlock()
for _, node := range peers { for _, node := range peers {
err := sendMessage(node, authCode, "election", authCode) msg := Message{
ID: hostID.Pretty(),
Type: "election",
Content: authCode,
}
err := sendMessage(node, msg)
if err != nil { if err != nil {
log.Printf("Error sending election message to %s: %v", node, err) log.Printf("Error sending election message to %s: %v", node, err)
} }

View file

@ -9,14 +9,19 @@ import (
// Function to sync updates across all nodes // Function to sync updates across all nodes
func nodeUpdateSync() { func nodeUpdateSync() {
fmt.Println("Syncing updates across all nodes...") fmt.Println("Syncing updates across all nodes...")
for _, peer := range peers { for _, peerAddr := range peers {
fmt.Printf("Notifying node %s about update...\n", peer) fmt.Printf("Notifying node %s about update...\n", peerAddr)
err := sendMessage(peer, authCode, "update", "Start update process") msg := Message{
ID: hostID.Pretty(),
Type: "update",
Content: "Start update process",
}
err := sendMessage(peerAddr, msg)
if err != nil { if err != nil {
log.Printf("Failed to notify node %s: %v\n", peer, err) log.Printf("Failed to notify node %s: %v\n", peerAddr, err)
continue continue
} }
fmt.Printf("Node %s notified. Waiting for it to update...\n", peer) fmt.Printf("Node %s notified. Waiting for it to update...\n", peerAddr)
time.Sleep(30 * time.Second) // Adjust sleep time as needed to allow for updates time.Sleep(30 * time.Second) // Adjust sleep time as needed to allow for updates
} }
fmt.Println("All nodes have been updated.") fmt.Println("All nodes have been updated.")

195
node.go
View file

@ -2,8 +2,7 @@ package main
import ( import (
"bytes" "bytes"
"crypto/sha256" "crypto/rand"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -11,6 +10,10 @@ import (
"net/http" "net/http"
"sync" "sync"
"time" "time"
libp2p "github.com/libp2p/go-libp2p"
crypto "github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
) )
var ( var (
@ -18,6 +21,7 @@ var (
peers []string peers []string
authMutex sync.Mutex authMutex sync.Mutex
authenticated = make(map[string]bool) authenticated = make(map[string]bool)
hostID peer.ID
) )
type Message struct { type Message struct {
@ -35,74 +39,99 @@ type CrawlerConfig struct {
func loadNodeConfig() { func loadNodeConfig() {
config := loadConfig() config := loadConfig()
authCode = config.AuthCode // nuh uh authCode = config.AuthCode
peers = config.Peers peers = config.Peers
} }
func initP2P() (peer.ID, error) {
priv, _, err := crypto.GenerateKeyPairWithReader(crypto.Ed25519, 2048, rand.Reader)
if err != nil {
return "", fmt.Errorf("failed to generate key pair: %v", err)
}
h, err := libp2p.New(libp2p.Identity(priv))
if err != nil {
return "", fmt.Errorf("failed to create libp2p host: %v", err)
}
return h.ID(), nil
}
func sendMessage(serverAddr string, msg Message) error {
msgBytes, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %v", err)
}
req, err := http.NewRequest("POST", serverAddr, bytes.NewBuffer(msgBytes))
if err != nil {
return fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", authCode)
client := &http.Client{
Timeout: time.Second * 10,
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := ioutil.ReadAll(resp.Body)
return fmt.Errorf("server error: %s", body)
}
return nil
}
func handleNodeRequest(w http.ResponseWriter, r *http.Request) { func handleNodeRequest(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return return
} }
body, err := ioutil.ReadAll(r.Body) auth := r.Header.Get("Authorization")
if err != nil { if auth != authCode {
http.Error(w, "Error reading request body", http.StatusInternalServerError) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
defer r.Body.Close()
var msg Message var msg Message
if err := json.Unmarshal(body, &msg); err != nil { err := json.NewDecoder(r.Body).Decode(&msg)
http.Error(w, "Error parsing JSON", http.StatusBadRequest)
return
}
if !isAuthenticated(msg.ID) {
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}
interpretMessage(msg)
fmt.Fprintln(w, "Message received")
}
func handleAuth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
http.Error(w, "Error reading request body", http.StatusInternalServerError) http.Error(w, "Error parsing JSON", http.StatusBadRequest)
return return
} }
defer r.Body.Close() defer r.Body.Close()
var authRequest CrawlerConfig log.Printf("Received message: %+v\n", msg)
if err := json.Unmarshal(body, &authRequest); err != nil { w.Write([]byte("Message received"))
http.Error(w, "Error parsing JSON", http.StatusBadRequest)
return
}
expectedCode := GenerateRegistrationCode(authRequest.Host, authRequest.Port, authCode) interpretMessage(msg)
if authRequest.AuthCode != expectedCode {
http.Error(w, "Invalid auth code", http.StatusUnauthorized)
return
}
authMutex.Lock()
authenticated[authRequest.ID] = true
authMutex.Unlock()
fmt.Fprintln(w, "Authenticated successfully")
} }
func isAuthenticated(id string) bool { func startNodeClient() {
authMutex.Lock() for {
defer authMutex.Unlock() for _, peerAddr := range peers {
return authenticated[id] msg := Message{
ID: hostID.Pretty(),
Type: "test",
Content: "This is a test message from the client node",
}
err := sendMessage(peerAddr, msg)
if err != nil {
log.Printf("Error sending message to %s: %v", peerAddr, err)
} else {
log.Println("Message sent successfully to", peerAddr)
}
}
time.Sleep(10 * time.Second)
}
} }
func interpretMessage(msg Message) { func interpretMessage(msg Message) {
@ -120,71 +149,3 @@ func interpretMessage(msg Message) {
fmt.Println("Received unknown message type:", msg.Type) fmt.Println("Received unknown message type:", msg.Type)
} }
} }
func sendMessage(address, id, msgType, content string) error {
msg := Message{
ID: id,
Type: msgType,
Content: content,
}
msgBytes, err := json.Marshal(msg)
if err != nil {
return err
}
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/node", address), bytes.NewBuffer(msgBytes))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := ioutil.ReadAll(resp.Body)
return fmt.Errorf("failed to send message: %s", body)
}
return nil
}
func startNodeClient(addresses []string) {
for _, address := range addresses {
go func(addr string) {
for {
err := sendMessage(addr, authCode, "test", "This is a test message")
if err != nil {
fmt.Println("Error sending test message to", addr, ":", err)
continue
}
time.Sleep(10 * time.Second)
}
}(address)
}
}
func GenerateRegistrationCode(host string, port int, authCode string) string {
data := fmt.Sprintf("%s:%d:%s", host, port, authCode)
hash := sha256.Sum256([]byte(data))
return hex.EncodeToString(hash[:])
}
func ParseRegistrationCode(code string, host string, port int, authCode string) (string, int, string, error) {
data := fmt.Sprintf("%s:%d:%s", host, port, authCode)
hash := sha256.Sum256([]byte(data))
expectedCode := hex.EncodeToString(hash[:])
log.Printf("Parsing registration code: %s", code)
log.Printf("Expected registration code: %s", expectedCode)
if expectedCode != code {
return "", 0, "", fmt.Errorf("invalid registration code")
}
return host, port, authCode, nil
}