import socket
import asyncore
import sys
import traceback
from errno import EWOULDBLOCK
from time import time

from xdrlib2 import *
from sunrpc import *

# Largest datagram read by a single recvfrom() call.
BUFSIZE = 8192

# Registry of exported services: (program, version) -> (programDef, server).
_servers = {}


def exportService(programDef, server):
    """Register *server* as the implementation of *programDef*.

    The service is keyed by ``(programDef.program, programDef.version)``;
    registering the same pair again replaces the previous entry.
    """
    program = programDef.program
    version = programDef.version
    _servers[(program, version)] = (programDef, server)


class UDPServer(asyncore.dispatcher):
    """asyncore dispatcher bound to a UDP port.

    Incoming datagrams are handed to handle_readfrom(); outgoing datagrams
    are queued in ``self.outbuf`` and flushed by initiate_send().
    """

    def __init__(self, port, host=''):
        # Pending (data, addr) datagrams, oldest first.
        self.outbuf = []
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        asyncore.dispatcher.__init__(self, sock)
        self.bind((host, port))

    def handle_read_event(self):
        """Read one datagram and pass it to handle_readfrom()."""
        try:
            data, addr = self.socket.recvfrom(BUFSIZE)
        except socket.error:
            # A failed read only loses this datagram; report it and keep
            # the server socket open.
            traceback.print_exc()
            return
        self.handle_readfrom(data, addr)

    def writable(self):
        """Ask asyncore for write events only while datagrams are queued."""
        return len(self.outbuf)

    def sendto(self, data, addr):
        """Queue *data* for delivery to *addr* and try to send it now."""
        self.outbuf.append((data, addr))
        self.initiate_send()

    def initiate_send(self):
        """Send queued datagrams until the queue drains or the socket would block.

        A datagram leaves the queue only after sendto() did not fail with
        EWOULDBLOCK, so a full socket buffer delays the packet instead of
        discarding it. Any other socket error drops the offending datagram
        and is re-raised with its original traceback.
        """
        queue = self.outbuf
        while queue:
            data, addr = queue[0]
            try:
                result = self.socket.sendto(data, addr)
            except socket.error:
                # sys.exc_info() is valid syntax on every Python 2 release.
                why = sys.exc_info()[1]
                if why.args and why.args[0] == EWOULDBLOCK:
                    # Keep the datagram; handle_write_event() retries it.
                    return
                del queue[0]
                raise
            del queue[0]
            if result != len(data):
                self.log('Sent packet truncated to %d bytes' % result)

    def handle_write_event(self):
        """Flush the outgoing queue when asyncore reports the socket writable."""
        self.initiate_send()

    def handle_error(self, *info):
        """Print the pending traceback and close the socket."""
        traceback.print_exc()
        print('uncaptured python exception, closing channel %s' % repr(self))
        self.close()

    def handle_readfrom(self, data, addr):
        """Handle one datagram; subclasses override this. The default only logs."""
        self.log('Unhandled readfrom(): got %d bytes from %s'
                 % (len(data), addr))


class RPCError(Exception):
    """RPC Error"""


class AuthError(RPCError):
    """Authorization error"""

    def __init__(self, stat):
        RPCError.__init__(self, stat)
        # Status sent back to the caller in the AUTH_ERROR rejection.
        self.stat = stat


class ProcedureNotFoundError(RPCError):
    """Procedure not found"""


class AuthInfo:
    """Caller address handed to every procedure as ``hostname`` and ``port``."""

    def __init__(self, addr):
        self.hostname, self.port = addr


class Stopwatch:
    """Accumulate per-request time() stamps for crude profiling.

    begin() starts a record, mark() appends an intermediate stamp, and
    end() appends a final stamp and stores the record in ``self.data``.
    """

    def __init__(self):
        self.data = []

    def begin(self):
        self.record = [time()]

    def mark(self):
        self.record.append(time())

    def end(self):
        self.record.append(time())
        self.data.append(self.record)

    def report(self):
        """Summarize the gaps between consecutive stamps, by position.

        Returns ``(averages, total, count)``. For each interval index i,
        ``total[i]`` is the summed duration in seconds and ``count[i]`` is how
        many records contain that interval. ``averages[i]`` is the mean
        duration in milliseconds.
        """
        total = []
        count = []
        for record in self.data:
            for idx in range(len(record) - 1):
                delta = record[idx + 1] - record[idx]
                if idx >= len(total):
                    # total and count always grow together.
                    total.append(delta)
                    count.append(1)
                else:
                    total[idx] = total[idx] + delta
                    count[idx] = count[idx] + 1
        averages = [t * 1000 / c for t, c in zip(total, count)]
        return averages, total, count


class CrudeStopwatch(Stopwatch):
    """Stopwatch whose mark() is a no-op, so only begin()..end() is timed."""

    def mark(self):
        pass


# Shared profiler used by UDP_RPCServer.
_stopwatch = CrudeStopwatch()


class UDP_RPCServer(UDPServer):
    """Sun RPC server over UDP.

    Each datagram is decoded as an RpcMsg and routed to the service that
    exportService() registered for its (program, version). The call is then
    answered with an accepted or rejected reply.
    """

    # NOTE(review): _xid_cache_len, self.xids and self.xid_replies are not
    # used in this file; presumably reserved for a duplicate-request cache.
    _xid_cache_len = 100

    def __init__(self, port, host=''):
        UDPServer.__init__(self, port, host)
        self.xids = []
        self.xid_replies = {}

    def handle_readfrom(self, data, addr):
        """Decode, dispatch and answer one RPC call datagram from *addr*."""
        _stopwatch.begin()
        msg = createBlankInstance(RpcMsg)
        u = XDRUnpacker(data)
        try:
            msg.__xdr_unpack__(msg, u)
        except XDRError:
            # Without a decoded message there is no xid to answer. Drop the
            # datagram rather than letting the error reach asyncore, whose
            # handle_error() would close the server socket.
            self.log('Dropping undecodable RPC message from %s' % (addr,))
            return
        xid = msg.xid
        try:
            self._dispatchCall(xid, msg.body, u, addr)
        except XDRError:
            self.sendAcceptError(xid, AcceptStat.GARBAGE_ARGS, addr)
        except ProcedureNotFoundError:
            self.sendAcceptError(xid, AcceptStat.PROC_UNAVAIL, addr)
        except AuthError:
            e = sys.exc_info()[1]
            self.sendRejectedReply(xid, RejectStat.AUTH_ERROR, e.stat, addr)
        except Exception:
            # Any other failure while serving the call becomes SYSTEM_ERR.
            traceback.print_exc()
            self.sendAcceptError(xid, AcceptStat.SYSTEM_ERR, addr)
        _stopwatch.end()

    def _dispatchCall(self, xid, body, u, addr):
        """Route a decoded call *body* to its procedure and send the reply.

        Raises ProcedureNotFoundError for an unknown procedure number.
        Errors from argument decoding or from the procedure itself propagate
        to the caller.
        """
        if body.rpcvers != 2:
            # Only RPC protocol version 2 is served. Reject and stop here so
            # the call is not also executed and answered a second time.
            self.sendRejectedReply(xid, RejectStat.RPC_MISMATCH, 2, addr)
            return
        key = (body.prog, body.vers)
        if key not in _servers:
            self.sendProgramMismatch(xid, body.prog, addr)
            return
        progDef, server = _servers[key]
        if not progDef.procs.has_key(body.proc):
            raise ProcedureNotFoundError
        procDef = progDef.procs[body.proc]
        needAuth = procDef[-1]
        if needAuth:
            # TODO(review): authentication is not implemented; procedures
            # flagged as requiring it are currently served unauthenticated.
            pass
        authinfo = AuthInfo(addr)
        res = self.doProcedure(procDef, u, body, server, authinfo)
        self.sendReply(xid, res, addr)

    def sendAcceptError(self, xid, stat, addr, low_version=0, high_version=0):
        """Send an accepted-but-failed reply with status *stat*.

        The mismatch range is included only when *high_version* > 0.
        """
        verf = OpaqueAuth(flavor=AuthFlavor.AUTH_NONE, body='')
        if high_version > 0:
            ar = AcceptedReply(verf=verf, accept_stat=stat,
                               mismatch_low=low_version,
                               mismatch_high=high_version)
        else:
            ar = AcceptedReply(verf=verf, accept_stat=stat)
        reply = RpcMsg(xid=xid, body=ReplyBody(reply=ar))
        self.sendto(toXDR(reply), addr)

    def sendRejectedReply(self, xid, reject_stat, info, addr):
        """Send a rejected reply.

        *info* is the supported RPC version for RPC_MISMATCH (used as both
        bounds), otherwise the auth status.
        """
        if reject_stat == RejectStat.RPC_MISMATCH:
            rr = RejectedReply(reject_stat=reject_stat,
                               mismatch_low=info, mismatch_high=info)
        else:
            rr = RejectedReply(reject_stat=reject_stat, auth_stat=info)
        reply = RpcMsg(xid=xid, body=ReplyBody(reply=rr))
        self.sendto(toXDR(reply), addr)

    def sendProgramMismatch(self, xid, requested_prog, addr):
        """Answer a call for an unregistered (program, version) pair.

        If any positive version of *requested_prog* is exported, reply
        PROG_MISMATCH with the exported version range; otherwise reply
        PROG_UNAVAIL.
        """
        low_version = 0
        high_version = 0
        for prog, vers in _servers.keys():
            if prog == requested_prog:
                if low_version <= 0 or vers < low_version:
                    low_version = vers
                if vers > high_version:
                    high_version = vers
        if high_version:
            self.sendAcceptError(xid, AcceptStat.PROG_MISMATCH, addr,
                                 low_version, high_version)
        else:
            self.sendAcceptError(xid, AcceptStat.PROG_UNAVAIL, addr)

    def doProcedure(self, procDef, u, body, server, authinfo):
        """Decode the call arguments from *u* and invoke the procedure.

        *procDef* is a ``(procName, argsClass, resClass, needAuth)`` tuple.
        Returns whatever ``server._callProcedure()`` returns.
        """
        procName, argsClass, resClass, needAuth = procDef
        if argsClass:
            args = createBlankInstance(argsClass)
            args.__xdr_unpack__(args, u)
        else:
            args = None
        # NOTE(review): presumably verifies that no undecoded bytes remain.
        u.done()
        _stopwatch.mark()
        res = server._callProcedure(procName, args, resClass, authinfo)
        _stopwatch.mark()
        return res

    def sendReply(self, xid, res, addr):
        """Send a SUCCESS reply for *xid* with *res* (if true) packed after it."""
        verf = OpaqueAuth(flavor=AuthFlavor.AUTH_NONE, body='')
        ar = AcceptedReply(verf=verf, accept_stat=AcceptStat.SUCCESS)
        reply = RpcMsg(xid=xid, body=ReplyBody(reply=ar))
        p = XDRPacker()
        reply.__xdr_pack__(reply, p)
        if res:
            res.__xdr_pack__(res, p)
        _stopwatch.mark()
        self.sendto(p.get_buffer(), addr)


class ProcedureServer:
    """Base class for service objects passed to exportService().

    Each RPC procedure is a method named after the procName in its procDef.
    The method receives the caller's AuthInfo as its only positional
    argument and the decoded argument fields as keyword arguments.
    """

    def _callProcedure(self, procName, args, resClass, authinfo):
        """Invoke method *procName* and check its result against *resClass*.

        Raises TypeError when a procedure with no declared result returns a
        true value, or when the result's class is not exactly *resClass*.
        """
        method = getattr(self, procName)
        if args:
            res = method(authinfo, **args.__dict__)
        else:
            res = method(authinfo)
        if not resClass:
            if res:
                raise TypeError('Expected no result, got %s' % repr(res))
        elif res.__class__ != resClass:
            raise TypeError('Expected a result of type %s, '
                            'got %s' % (resClass, repr(res)))
        return res


def test():
    """Serve RPC requests on UDP port 5489 via asyncore.loop()."""
    # The dispatcher registers itself in asyncore's socket map on creation.
    UDP_RPCServer(5489)
    asyncore.loop()


if __name__ == '__main__':
    test()