import errno
import socket
import struct
import sys
class Error(Exception):
    """Protocol failure raised by this module.

    buffsock raises it with the requested byte count when the connection
    yields no more data before that many bytes arrived; nbd_request raises
    it with the received magic number when a request header is malformed.
    """
class buffsock:
"Buffered socket wrapper; always returns the amount of data you want."
def __init__(self, sock): self.sock = sock
def recv(self, nbytes):
rv = ''
while len(rv) < nbytes:
more = self.sock.recv(nbytes - len(rv))
if more == '': raise Error(nbytes)
rv += more
return rv
def send(self, astring): self.sock.send(astring)
def close(self): self.sock.close()
class debugsock:
    """Debugging socket wrapper.

    Forwards recv, send, sendall and close to the wrapped socket and traces
    each call, followed by the repr of its result, on standard output.
    """

    def __init__(self, sock):
        self.sock = sock

    def _trace(self, label, func, arg):
        # Emit the label before calling so that a call which blocks (e.g. a
        # recv waiting for the peer) is already visible in the trace.
        sys.stdout.write(label + " = ")
        sys.stdout.flush()
        rv = func(arg)
        sys.stdout.write(repr(rv) + "\n")
        return rv

    def recv(self, nbytes):
        """Trace and forward recv(nbytes); return what the socket returned."""
        return self._trace("recv(%d)" % nbytes, self.sock.recv, nbytes)

    def send(self, astring):
        """Trace and forward send(astring); return the socket's result."""
        return self._trace("send(%r)" % (astring,), self.sock.send, astring)

    def sendall(self, astring):
        """Trace and forward sendall(astring); return the socket's result."""
        return self._trace("sendall(%r)" % (astring,), self.sock.sendall,
                           astring)

    def close(self):
        """Trace and forward close()."""
        sys.stdout.write("close()\n")
        self.sock.close()
def negotiation(exportsize):
    """Returns initial NBD negotiation sequence for exportsize in bytes.

    The sequence is 152 bytes long: the ASCII string "NBDMAGIC", the eight
    bytes 00 00 42 02 81 86 12 53, exportsize as a big-endian unsigned
    64-bit integer, and 128 zero bytes.
    """
    # Packing every field through struct yields the native byte-string
    # type on both Python 2 (str) and Python 3 (bytes); concatenating str
    # literals with struct output raises TypeError on Python 3.
    return struct.pack('>8sQQ128x', b'NBDMAGIC', 0x0000420281861253,
                       exportsize)
def nbd_reply(error=0, handle=1, data=b''):
    """Construct an NBD reply.

    The reply is the four bytes 67 44 66 98, error as a big-endian unsigned
    32-bit integer, the 8-byte handle, and then data.

    error  -- value for the reply's error field (serveclient passes 0).
    handle -- the 8-byte handle of the request being answered. The default
              of 1 does not pass the check below, so callers must always
              supply a handle.
    data   -- payload appended after the header; empty by default.
    """
    # bytes is an alias of str on Python 2.6+, so this accepts the native
    # byte-string type that struct.unpack('8s') produces on Python 2 and 3.
    assert isinstance(handle, bytes) and len(handle) == 8
    return struct.pack('>LL', 0x67446698, error) + handle + data
# Values of nbd_request.type that serveclient distinguishes.
read_request, write_request, disconnect_request = 0, 1, 2
class nbd_request:
    """Decodes an NBD request off the TCP socket.

    Attributes set by the constructor:
      magic  -- request magic number; must equal 0x25609513
      type   -- request type, compared against read_request, write_request
                and disconnect_request
      handle -- 8-byte handle, echoed back by reply()
      offset -- byte offset into the export
      len    -- number of bytes the request covers
      data   -- payload of a write request; empty for every other type

    Raises Error(magic) if the header's magic number is wrong, and whatever
    buffsock.recv raises if the connection ends early.
    """

    # Header: magic (u32), type (u32), handle (8 bytes), offset (u64),
    # length (u32), all big-endian -- 28 bytes. Compiled once, shared by
    # every request.
    _header = struct.Struct('>LL8sQL')

    def __init__(self, conn):
        conn = buffsock(conn)
        (self.magic, self.type, self.handle, self.offset,
         self.len) = self._header.unpack(conn.recv(self._header.size))
        if self.magic != 0x25609513:
            raise Error(self.magic)
        if self.type == write_request:
            # buffsock.recv returns exactly self.len bytes or raises, so no
            # length check is needed afterwards.
            self.data = conn.recv(self.len)
        else:
            # Define data for every request so callers never hit
            # AttributeError on non-write requests.
            self.data = b''

    def reply(self, error, data=b''):
        """Return the encoded reply to this request, echoing its handle."""
        return nbd_reply(error=error, handle=self.handle, data=data)

    def range(self):
        """Return the slice [offset, offset + len) this request covers."""
        return slice(self.offset, self.offset + self.len)
def _serve_request(req, afile, size):
    """Carry out one request against afile and return the encoded reply.

    Requests whose range extends past size, and request types other than
    read and write, are answered with an errno value instead of being
    performed.
    """
    in_range = req.offset + req.len <= size
    if req.type == read_request:
        if not in_range:
            return req.reply(error=errno.EINVAL)
        afile.seek(req.offset)
        return req.reply(error=0, data=afile.read(req.len))
    if req.type == write_request:
        if not in_range:
            # nbd_request has already consumed the payload, so refusing the
            # write keeps the request stream in sync.
            return req.reply(error=errno.ENOSPC)
        afile.seek(req.offset)
        afile.write(req.data)
        afile.flush()
        return req.reply(error=0)
    # Any other request type: answer with an error rather than sending
    # nothing and leaving the client waiting for a reply.
    return req.reply(error=errno.EINVAL)


def serveclient(asock, afile):
    """Serves a single client until it exits.

    Advertises afile's current size, then answers read and write requests
    directly against afile until the client sends a disconnect request.
    asock is always closed on return, including when an exception (such as
    Error from a truncated request) propagates to the caller.
    """
    try:
        # sendall never truncates a large read reply the way a bare send()
        # may; fall back to send() for wrappers that only implement send().
        send = getattr(asock, 'sendall', asock.send)
        # Serve straight from the file instead of copying it into a list
        # with one entry per byte, which cost many times the file size in RAM.
        afile.seek(0, 2)  # whence=2: position relative to end of file
        size = afile.tell()
        send(negotiation(size))
        while True:
            req = nbd_request(asock)
            if req.type == disconnect_request:
                return
            send(_serve_request(req, afile, size))
    finally:
        asock.close()
def mainloop(listensock, afile):
    """Serves clients forever, one connection at a time.

    If serving a client raises Error (e.g. the connection ends mid-request
    or a request has a bad magic number) or a socket error, that client's
    socket is closed, the error is reported on standard output, and the
    loop goes on to accept the next client instead of exiting.
    """
    while True:
        sock, addr = listensock.accept()
        sys.stdout.write("got conn on %s\n" % (addr,))
        try:
            serveclient(sock, afile)
        except (Error, socket.error):
            sys.stdout.write("dropped %s: %r\n" % (addr, sys.exc_info()[1]))
            sock.close()
def main(argv):
    """Given a port and a filename, serves up the file.

    argv[1] is the TCP port to listen on (all interfaces) and argv[2] the
    path of the file to export, opened read/write in binary mode. Extra
    arguments are ignored. Exits with a usage message when either is
    missing. The listening socket and the file are closed on the way out.
    """
    if len(argv) < 3:
        sys.exit("usage: %s PORT FILENAME" % (argv[0] if argv else 'nbd'))
    # open() rather than the Python-2-only file() builtin.
    with open(argv[2], 'rb+') as afile:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Allow rebinding the port immediately after a restart.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(('', int(argv[1])))
            sock.listen(5)
            mainloop(sock, afile)
        finally:
            sock.close()
if __name__ == '__main__':
    # Script entry point: argv[1] is the port, argv[2] the file to serve.
    main(sys.argv)