Initial import
[samba] / source / lib / messages.c
1 /* 
2    Unix SMB/CIFS implementation.
3    Samba internal messaging functions
4    Copyright (C) Andrew Tridgell 2000
5    Copyright (C) 2001 by Martin Pool
6    Copyright (C) 2002 by Jeremy Allison
7    
8    This program is free software; you can redistribute it and/or modify
9    it under the terms of the GNU General Public License as published by
10    the Free Software Foundation; either version 2 of the License, or
11    (at your option) any later version.
12    
13    This program is distributed in the hope that it will be useful,
14    but WITHOUT ANY WARRANTY; without even the implied warranty of
15    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16    GNU General Public License for more details.
17    
18    You should have received a copy of the GNU General Public License
19    along with this program; if not, write to the Free Software
20    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
21 */
22
23 /**
24   @defgroup messages Internal messaging framework
25   @{
26   @file messages.c
27   
28   @brief  Module for internal messaging between Samba daemons. 
29
30    The idea is that if a part of Samba wants to do communication with
31    another Samba process then it will do a message_register() of a
32    dispatch function, and use message_send_pid() to send messages to
33    that process.
34
35    The dispatch function is given the pid of the sender, and it can
36    use that to reply by message_send_pid().  See ping_message() for a
37    simple example.
38
39    @caution Dispatch functions must be able to cope with incoming
40    messages on an *odd* byte boundary.
41
42    This system doesn't have any inherent size limitations but is not
43    very efficient for large messages or when messages are sent in very
44    quick succession.
45
46 */
47
48 #include "includes.h"
49
50 /* the locking database handle */
51 static TDB_CONTEXT *tdb;
52 static int received_signal;
53
54 /* change the message version with any incompatible changes in the protocol */
55 #define MESSAGE_VERSION 1
56
57 struct message_rec {
58         int msg_version;
59         int msg_type;
60         struct process_id dest;
61         struct process_id src;
62         size_t len;
63 };
64
65 /* we have a linked list of dispatch handlers */
66 static struct dispatch_fns {
67         struct dispatch_fns *next, *prev;
68         int msg_type;
69         void (*fn)(int msg_type, struct process_id pid, void *buf, size_t len);
70 } *dispatch_fns;
71
72 /****************************************************************************
73  Notifications come in as signals.
74 ****************************************************************************/
75
76 static void sig_usr1(void)
77 {
78         received_signal = 1;
79         sys_select_signal(SIGUSR1);
80 }
81
82 /****************************************************************************
83  A useful function for testing the message system.
84 ****************************************************************************/
85
86 static void ping_message(int msg_type, struct process_id src,
87                          void *buf, size_t len)
88 {
89         const char *msg = buf ? buf : "none";
90         DEBUG(1,("INFO: Received PING message from PID %s [%s]\n",
91                  procid_str_static(&src), msg));
92         message_send_pid(src, MSG_PONG, buf, len, True);
93 }
94
95 /****************************************************************************
96  Initialise the messaging functions. 
97 ****************************************************************************/
98
99 BOOL message_init(void)
100 {
101         if (tdb) return True;
102
103         tdb = tdb_open_log(lock_path("messages.tdb"), 
104                        0, TDB_CLEAR_IF_FIRST|TDB_DEFAULT, 
105                        O_RDWR|O_CREAT,0600);
106
107         if (!tdb) {
108                 DEBUG(0,("ERROR: Failed to initialise messages database\n"));
109                 return False;
110         }
111
112         CatchSignal(SIGUSR1, SIGNAL_CAST sig_usr1);
113
114         message_register(MSG_PING, ping_message);
115
116         /* Register some debugging related messages */
117
118         register_msg_pool_usage();
119         register_dmalloc_msgs();
120
121         return True;
122 }
123
124 /*******************************************************************
125  Form a static tdb key from a pid.
126 ******************************************************************/
127
128 static TDB_DATA message_key_pid(struct process_id pid)
129 {
130         static char key[20];
131         TDB_DATA kbuf;
132
133         slprintf(key, sizeof(key)-1, "PID/%s", procid_str_static(&pid));
134         
135         kbuf.dptr = (char *)key;
136         kbuf.dsize = strlen(key)+1;
137         return kbuf;
138 }
139
140 /****************************************************************************
141  Notify a process that it has a message. If the process doesn't exist 
142  then delete its record in the database.
143 ****************************************************************************/
144
145 static BOOL message_notify(struct process_id procid)
146 {
147         pid_t pid = procid.pid;
148         /*
149          * Doing kill with a non-positive pid causes messages to be
150          * sent to places we don't want.
151          */
152
153         SMB_ASSERT(pid > 0);
154
155         if (kill(pid, SIGUSR1) == -1) {
156                 if (errno == ESRCH) {
157                         DEBUG(2,("pid %d doesn't exist - deleting messages record\n", (int)pid));
158                         tdb_delete(tdb, message_key_pid(procid));
159                 } else {
160                         DEBUG(2,("message to process %d failed - %s\n", (int)pid, strerror(errno)));
161                 }
162                 return False;
163         }
164         return True;
165 }
166
167 /****************************************************************************
168  Send a message to a particular pid.
169 ****************************************************************************/
170
171 static BOOL message_send_pid_internal(struct process_id pid, int msg_type,
172                                       const void *buf, size_t len,
173                                       BOOL duplicates_allowed,
174                                       unsigned int timeout)
175 {
176         TDB_DATA kbuf;
177         TDB_DATA dbuf;
178         TDB_DATA old_dbuf;
179         struct message_rec rec;
180         char *ptr;
181         struct message_rec prec;
182
183         /*
184          * Doing kill with a non-positive pid causes messages to be
185          * sent to places we don't want.
186          */
187
188         SMB_ASSERT(procid_to_pid(&pid) > 0);
189
190         rec.msg_version = MESSAGE_VERSION;
191         rec.msg_type = msg_type;
192         rec.dest = pid;
193         rec.src = procid_self();
194         rec.len = len;
195
196         kbuf = message_key_pid(pid);
197
198         dbuf.dptr = (void *)SMB_MALLOC(len + sizeof(rec));
199         if (!dbuf.dptr)
200                 return False;
201
202         memcpy(dbuf.dptr, &rec, sizeof(rec));
203         if (len > 0)
204                 memcpy((void *)((char*)dbuf.dptr+sizeof(rec)), buf, len);
205
206         dbuf.dsize = len + sizeof(rec);
207
208         if (duplicates_allowed) {
209
210                 /* If duplicates are allowed we can just append the message and return. */
211
212                 /* lock the record for the destination */
213                 if (timeout) {
214                         if (tdb_chainlock_with_timeout(tdb, kbuf, timeout) == -1) {
215                                 DEBUG(0,("message_send_pid_internal: failed to get chainlock with timeout %ul.\n", timeout));
216                                 return False;
217                         }
218                 } else {
219                         if (tdb_chainlock(tdb, kbuf) == -1) {
220                                 DEBUG(0,("message_send_pid_internal: failed to get chainlock.\n"));
221                                 return False;
222                         }
223                 }       
224                 tdb_append(tdb, kbuf, dbuf);
225                 tdb_chainunlock(tdb, kbuf);
226
227                 SAFE_FREE(dbuf.dptr);
228                 errno = 0;                    /* paranoia */
229                 return message_notify(pid);
230         }
231
232         /* lock the record for the destination */
233         if (timeout) {
234                 if (tdb_chainlock_with_timeout(tdb, kbuf, timeout) == -1) {
235                         DEBUG(0,("message_send_pid_internal: failed to get chainlock with timeout %ul.\n", timeout));
236                         return False;
237                 }
238         } else {
239                 if (tdb_chainlock(tdb, kbuf) == -1) {
240                         DEBUG(0,("message_send_pid_internal: failed to get chainlock.\n"));
241                         return False;
242                 }
243         }       
244
245         old_dbuf = tdb_fetch(tdb, kbuf);
246
247         if (!old_dbuf.dptr) {
248                 /* its a new record */
249
250                 tdb_store(tdb, kbuf, dbuf, TDB_REPLACE);
251                 tdb_chainunlock(tdb, kbuf);
252
253                 SAFE_FREE(dbuf.dptr);
254                 errno = 0;                    /* paranoia */
255                 return message_notify(pid);
256         }
257
258         /* Not a new record. Check for duplicates. */
259
260         for(ptr = (char *)old_dbuf.dptr; ptr < old_dbuf.dptr + old_dbuf.dsize; ) {
261                 /*
262                  * First check if the message header matches, then, if it's a non-zero
263                  * sized message, check if the data matches. If so it's a duplicate and
264                  * we can discard it. JRA.
265                  */
266
267                 if (!memcmp(ptr, &rec, sizeof(rec))) {
268                         if (!len || (len && !memcmp( ptr + sizeof(rec), buf, len))) {
269                                 tdb_chainunlock(tdb, kbuf);
270                                 DEBUG(10,("message_send_pid_internal: discarding duplicate message.\n"));
271                                 SAFE_FREE(dbuf.dptr);
272                                 SAFE_FREE(old_dbuf.dptr);
273                                 return True;
274                         }
275                 }
276                 memcpy(&prec, ptr, sizeof(prec));
277                 ptr += sizeof(rec) + prec.len;
278         }
279
280         /* we're adding to an existing entry */
281
282         tdb_append(tdb, kbuf, dbuf);
283         tdb_chainunlock(tdb, kbuf);
284
285         SAFE_FREE(old_dbuf.dptr);
286         SAFE_FREE(dbuf.dptr);
287
288         errno = 0;                    /* paranoia */
289         return message_notify(pid);
290 }
291
292 /****************************************************************************
293  Send a message to a particular pid - no timeout.
294 ****************************************************************************/
295
296 BOOL message_send_pid(struct process_id pid, int msg_type, const void *buf, size_t len, BOOL duplicates_allowed)
297 {
298         return message_send_pid_internal(pid, msg_type, buf, len, duplicates_allowed, 0);
299 }
300
301 /****************************************************************************
302  Send a message to a particular pid, with timeout in seconds.
303 ****************************************************************************/
304
305 BOOL message_send_pid_with_timeout(struct process_id pid, int msg_type, const void *buf, size_t len,
306                 BOOL duplicates_allowed, unsigned int timeout)
307 {
308         return message_send_pid_internal(pid, msg_type, buf, len, duplicates_allowed, timeout);
309 }
310
311 /****************************************************************************
312  Count the messages pending for a particular pid. Expensive....
313 ****************************************************************************/
314
315 unsigned int messages_pending_for_pid(struct process_id pid)
316 {
317         TDB_DATA kbuf;
318         TDB_DATA dbuf;
319         char *buf;
320         unsigned int message_count = 0;
321
322         kbuf = message_key_pid(pid);
323
324         dbuf = tdb_fetch(tdb, kbuf);
325         if (dbuf.dptr == NULL || dbuf.dsize == 0) {
326                 SAFE_FREE(dbuf.dptr);
327                 return 0;
328         }
329
330         for (buf = dbuf.dptr; dbuf.dsize > sizeof(struct message_rec);) {
331                 struct message_rec rec;
332                 memcpy(&rec, buf, sizeof(rec));
333                 buf += (sizeof(rec) + rec.len);
334                 dbuf.dsize -= (sizeof(rec) + rec.len);
335                 message_count++;
336         }
337
338         SAFE_FREE(dbuf.dptr);
339         return message_count;
340 }
341
342 /****************************************************************************
343  Retrieve all messages for the current process.
344 ****************************************************************************/
345
346 static BOOL retrieve_all_messages(char **msgs_buf, size_t *total_len)
347 {
348         TDB_DATA kbuf;
349         TDB_DATA dbuf;
350         TDB_DATA null_dbuf;
351
352         ZERO_STRUCT(null_dbuf);
353
354         *msgs_buf = NULL;
355         *total_len = 0;
356
357         kbuf = message_key_pid(pid_to_procid(sys_getpid()));
358
359         if (tdb_chainlock(tdb, kbuf) == -1)
360                 return False;
361
362         dbuf = tdb_fetch(tdb, kbuf);
363         /*
364          * Replace with an empty record to keep the allocated
365          * space in the tdb.
366          */
367         tdb_store(tdb, kbuf, null_dbuf, TDB_REPLACE);
368         tdb_chainunlock(tdb, kbuf);
369
370         if (dbuf.dptr == NULL || dbuf.dsize == 0) {
371                 SAFE_FREE(dbuf.dptr);
372                 return False;
373         }
374
375         *msgs_buf = dbuf.dptr;
376         *total_len = dbuf.dsize;
377
378         return True;
379 }
380
381 /****************************************************************************
382  Parse out the next message for the current process.
383 ****************************************************************************/
384
385 static BOOL message_recv(char *msgs_buf, size_t total_len, int *msg_type,
386                          struct process_id *src, char **buf, size_t *len)
387 {
388         struct message_rec rec;
389         char *ret_buf = *buf;
390
391         *buf = NULL;
392         *len = 0;
393
394         if (total_len - (ret_buf - msgs_buf) < sizeof(rec))
395                 return False;
396
397         memcpy(&rec, ret_buf, sizeof(rec));
398         ret_buf += sizeof(rec);
399
400         if (rec.msg_version != MESSAGE_VERSION) {
401                 DEBUG(0,("message version %d received (expected %d)\n", rec.msg_version, MESSAGE_VERSION));
402                 return False;
403         }
404
405         if (rec.len > 0) {
406                 if (total_len - (ret_buf - msgs_buf) < rec.len)
407                         return False;
408         }
409
410         *len = rec.len;
411         *msg_type = rec.msg_type;
412         *src = rec.src;
413         *buf = ret_buf;
414
415         return True;
416 }
417
418 /****************************************************************************
419  Receive and dispatch any messages pending for this process.
420  Notice that all dispatch handlers for a particular msg_type get called,
421  so you can register multiple handlers for a message.
422  *NOTE*: Dispatch functions must be able to cope with incoming
423  messages on an *odd* byte boundary.
424 ****************************************************************************/
425
426 void message_dispatch(void)
427 {
428         int msg_type;
429         struct process_id src;
430         char *buf;
431         char *msgs_buf;
432         size_t len, total_len;
433         struct dispatch_fns *dfn;
434         int n_handled;
435
436         if (!received_signal)
437                 return;
438
439         DEBUG(10,("message_dispatch: received_signal = %d\n", received_signal));
440
441         received_signal = 0;
442
443         if (!retrieve_all_messages(&msgs_buf, &total_len))
444                 return;
445
446         for (buf = msgs_buf; message_recv(msgs_buf, total_len, &msg_type, &src, &buf, &len); buf += len) {
447                 DEBUG(10,("message_dispatch: received msg_type=%d "
448                           "src_pid=%u\n", msg_type,
449                           (unsigned int) procid_to_pid(&src)));
450                 n_handled = 0;
451                 for (dfn = dispatch_fns; dfn; dfn = dfn->next) {
452                         if (dfn->msg_type == msg_type) {
453                                 DEBUG(10,("message_dispatch: processing message of type %d.\n", msg_type));
454                                 dfn->fn(msg_type, src, len ? (void *)buf : NULL, len);
455                                 n_handled++;
456                         }
457                 }
458                 if (!n_handled) {
459                         DEBUG(5,("message_dispatch: warning: no handlers registed for "
460                                  "msg_type %d in pid %u\n",
461                                  msg_type, (unsigned int)sys_getpid()));
462                 }
463         }
464         SAFE_FREE(msgs_buf);
465 }
466
467 /****************************************************************************
468  Register a dispatch function for a particular message type.
469  *NOTE*: Dispatch functions must be able to cope with incoming
470  messages on an *odd* byte boundary.
471 ****************************************************************************/
472
473 void message_register(int msg_type, 
474                       void (*fn)(int msg_type, struct process_id pid,
475                                  void *buf, size_t len))
476 {
477         struct dispatch_fns *dfn;
478
479         dfn = SMB_MALLOC_P(struct dispatch_fns);
480
481         if (dfn != NULL) {
482
483                 ZERO_STRUCTPN(dfn);
484
485                 dfn->msg_type = msg_type;
486                 dfn->fn = fn;
487
488                 DLIST_ADD(dispatch_fns, dfn);
489         }
490         else {
491         
492                 DEBUG(0,("message_register: Not enough memory. malloc failed!\n"));
493         }
494 }
495
496 /****************************************************************************
497  De-register the function for a particular message type.
498 ****************************************************************************/
499
500 void message_deregister(int msg_type)
501 {
502         struct dispatch_fns *dfn, *next;
503
504         for (dfn = dispatch_fns; dfn; dfn = next) {
505                 next = dfn->next;
506                 if (dfn->msg_type == msg_type) {
507                         DLIST_REMOVE(dispatch_fns, dfn);
508                         SAFE_FREE(dfn);
509                 }
510         }       
511 }
512
513 struct msg_all {
514         int msg_type;
515         uint32 msg_flag;
516         const void *buf;
517         size_t len;
518         BOOL duplicates;
519         int n_sent;
520 };
521
522 /****************************************************************************
523  Send one of the messages for the broadcast.
524 ****************************************************************************/
525
526 static int traverse_fn(TDB_CONTEXT *the_tdb, TDB_DATA kbuf, TDB_DATA dbuf, void *state)
527 {
528         struct connections_data crec;
529         struct msg_all *msg_all = (struct msg_all *)state;
530
531         if (dbuf.dsize != sizeof(crec))
532                 return 0;
533
534         memcpy(&crec, dbuf.dptr, sizeof(crec));
535
536         if (crec.cnum != -1)
537                 return 0;
538
539         /* Don't send if the receiver hasn't registered an interest. */
540
541         if(!(crec.bcast_msg_flags & msg_all->msg_flag))
542                 return 0;
543
544         /* If the msg send fails because the pid was not found (i.e. smbd died), 
545          * the msg has already been deleted from the messages.tdb.*/
546
547         if (!message_send_pid(crec.pid, msg_all->msg_type,
548                               msg_all->buf, msg_all->len,
549                               msg_all->duplicates)) {
550                 
551                 /* If the pid was not found delete the entry from connections.tdb */
552
553                 if (errno == ESRCH) {
554                         DEBUG(2,("pid %s doesn't exist - deleting connections %d [%s]\n",
555                                  procid_str_static(&crec.pid),
556                                  crec.cnum, crec.name));
557                         tdb_delete(the_tdb, kbuf);
558                 }
559         }
560         msg_all->n_sent++;
561         return 0;
562 }
563
564 /**
565  * Send a message to all smbd processes.
566  *
567  * It isn't very efficient, but should be OK for the sorts of
568  * applications that use it. When we need efficient broadcast we can add
569  * it.
570  *
571  * @param n_sent Set to the number of messages sent.  This should be
572  * equal to the number of processes, but be careful for races.
573  *
574  * @retval True for success.
575  **/
576 BOOL message_send_all(TDB_CONTEXT *conn_tdb, int msg_type,
577                       const void *buf, size_t len,
578                       BOOL duplicates_allowed,
579                       int *n_sent)
580 {
581         struct msg_all msg_all;
582
583         msg_all.msg_type = msg_type;
584         if (msg_type < 1000)
585                 msg_all.msg_flag = FLAG_MSG_GENERAL;
586         else if (msg_type > 1000 && msg_type < 2000)
587                 msg_all.msg_flag = FLAG_MSG_NMBD;
588         else if (msg_type > 2000 && msg_type < 2100)
589                 msg_all.msg_flag = FLAG_MSG_PRINT_NOTIFY;
590         else if (msg_type > 2100 && msg_type < 3000)
591                 msg_all.msg_flag = FLAG_MSG_PRINT_GENERAL;
592         else if (msg_type > 3000 && msg_type < 4000)
593                 msg_all.msg_flag = FLAG_MSG_SMBD;
594         else
595                 return False;
596
597         msg_all.buf = buf;
598         msg_all.len = len;
599         msg_all.duplicates = duplicates_allowed;
600         msg_all.n_sent = 0;
601
602         tdb_traverse(conn_tdb, traverse_fn, &msg_all);
603         if (n_sent)
604                 *n_sent = msg_all.n_sent;
605         return True;
606 }
607 /** @} **/