diff --git a/lib/tdb/pytdb.c b/lib/tdb/pytdb.c index f216f4c7b90..d6d0625686f 100644 --- a/lib/tdb/pytdb.c +++ b/lib/tdb/pytdb.c @@ -260,6 +260,58 @@ static PyObject *obj_store(PyTdbObject *self, PyObject *args) return Py_None; } + +typedef struct { + PyObject_HEAD + TDB_DATA current; + PyTdbObject *iteratee; +} PyTdbIteratorObject; + +static PyObject *tdb_iter_next(PyTdbIteratorObject *self) +{ + TDB_DATA current; + PyObject *ret; + if (self->current.dptr == NULL && self->current.dsize == 0) + return NULL; + current = self->current; + self->current = tdb_nextkey(self->iteratee->ctx, self->current); + ret = PyString_FromTDB_DATA(current); + return ret; +} + +static void tdb_iter_dealloc(PyTdbIteratorObject *self) +{ + Py_DECREF(self->iteratee); + PyObject_Del(self); +} + +PyTypeObject PyTdbIterator = { + .tp_name = "Iterator", + .tp_basicsize = sizeof(PyTdbIteratorObject), + .tp_iternext = (iternextfunc)tdb_iter_next, + .tp_dealloc = (destructor)tdb_iter_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT, + .tp_iter = PyObject_SelfIter, +}; + +static PyObject *tdb_object_iter(PyTdbObject *self) +{ + PyTdbIteratorObject *ret; + + ret = PyObject_New(PyTdbIteratorObject, &PyTdbIterator); + ret->current = tdb_firstkey(self->ctx); + ret->iteratee = self; + Py_INCREF(self); + return (PyObject *)ret; +} + +static PyObject *obj_clear(PyTdbObject *self) +{ + int ret = tdb_wipe_all(self->ctx); + PyErr_TDB_ERROR_IS_ERR_RAISE(ret, self->ctx); + return Py_None; +} + static PyMethodDef tdb_object_methods[] = { { "transaction_cancel", (PyCFunction)obj_transaction_cancel, METH_NOARGS, "S.transaction_cancel() -> None\n" @@ -293,6 +345,9 @@ static PyMethodDef tdb_object_methods[] = { "Check whether key exists in this database." }, { "store", (PyCFunction)obj_store, METH_VARARGS, "S.store(key, data, flag=REPLACE) -> None" "Store data." }, + { "iterkeys", (PyCFunction)tdb_object_iter, METH_NOARGS, "S.iterkeys() -> iterator" }, + { "clear", (PyCFunction)obj_clear, METH_NOARGS, "S.clear() -> None\n" + "Wipe the entire database." }, { NULL } }; @@ -400,7 +455,6 @@ static PyMappingMethods tdb_object_mapping = { .mp_subscript = (binaryfunc)obj_getitem, .mp_ass_subscript = (objobjargproc)obj_setitem, }; - PyTypeObject PyTdb = { .tp_name = "Tdb", .tp_basicsize = sizeof(PyTdbObject), @@ -411,7 +465,8 @@ PyTypeObject PyTdb = { .tp_repr = (reprfunc)tdb_object_repr, .tp_dealloc = (destructor)tdb_object_dealloc, .tp_as_mapping = &tdb_object_mapping, - .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE, + .tp_flags = Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_ITER, + .tp_iter = (getiterfunc)tdb_object_iter, }; static PyMethodDef tdb_methods[] = { @@ -427,6 +482,9 @@ void inittdb(void) if (PyType_Ready(&PyTdb) < 0) return; + if (PyType_Ready(&PyTdbIterator) < 0) + return; + m = Py_InitModule3("tdb", tdb_methods, "TDB is a simple key-value database similar to GDBM that supports multiple writers."); if (m == NULL) return; @@ -446,4 +504,6 @@ void inittdb(void) Py_INCREF(&PyTdb); PyModule_AddObject(m, "Tdb", (PyObject *)&PyTdb); + + Py_INCREF(&PyTdbIterator); } diff --git a/lib/tdb/python/tests/simple.py b/lib/tdb/python/tests/simple.py index c6999d8c7f9..d242e665beb 100644 --- a/lib/tdb/python/tests/simple.py +++ b/lib/tdb/python/tests/simple.py @@ -80,18 +80,6 @@ class SimpleTdbTests(TestCase): self.tdb["brainslug"] = "2" self.assertEquals(["bla", "brainslug"], list(self.tdb)) - def test_items(self): - self.tdb["bla"] = "1" - self.tdb["brainslug"] = "2" - self.assertEquals([("bla", "1"), ("brainslug", "2")], self.tdb.items()) - - def test_iteritems(self): - self.tdb["bloe"] = "2" - self.tdb["bla"] = "25" - i = self.tdb.iteritems() - self.assertEquals(set([("bla", "25"), ("bloe", "2")]), - set([i.next(), i.next()])) - def test_transaction_cancel(self): self.tdb["bloe"] = "2" self.tdb.transaction_start() @@ -112,39 +100,23 @@ class SimpleTdbTests(TestCase): i = iter(self.tdb) self.assertEquals(set(["bloe", "bla"]), set([i.next(), i.next()])) - def test_keys(self): - self.tdb["bloe"] = "2" - self.tdb["bla"] = "25" - self.assertEquals(["bla", "bloe"], self.tdb.keys()) - def test_iterkeys(self): self.tdb["bloe"] = "2" self.tdb["bla"] = "25" i = self.tdb.iterkeys() self.assertEquals(set(["bloe", "bla"]), set([i.next(), i.next()])) - def test_values(self): - self.tdb["bloe"] = "2" - self.tdb["bla"] = "25" - self.assertEquals(["25", "2"], self.tdb.values()) - - def test_itervalues(self): - self.tdb["bloe"] = "2" - self.tdb["bla"] = "25" - i = self.tdb.itervalues() - self.assertEquals(set(["25", "2"]), set([i.next(), i.next()])) - def test_clear(self): self.tdb["bloe"] = "2" self.tdb["bla"] = "25" - self.assertEquals(2, len(self.tdb)) + self.assertEquals(2, len(list(self.tdb))) self.tdb.clear() - self.assertEquals(0, len(self.tdb)) + self.assertEquals(0, len(list(self.tdb))) def test_len(self): - self.assertEquals(0, len(self.tdb)) + self.assertEquals(0, len(list(self.tdb))) self.tdb["entry"] = "value" - self.assertEquals(1, len(self.tdb)) + self.assertEquals(1, len(list(self.tdb))) if __name__ == '__main__':