From 15431d4fd8de67a03af4ee9cc8e774e1750e9da2 Mon Sep 17 00:00:00 2001
From: David Oberhollenzer <david.oberhollenzer@sigma-star.at>
Date: Mon, 14 Sep 2020 20:25:10 +0200
Subject: Add zstd stream compressor implementation to libfstream

Signed-off-by: David Oberhollenzer <david.oberhollenzer@sigma-star.at>
---
 lib/fstream/Makemodule.am                   |  6 ++
 lib/fstream/compress/ostream_compressor.c   |  5 ++
 lib/fstream/compress/zstd.c                 | 92 +++++++++++++++++++++++++++++
 lib/fstream/compressor.c                    | 10 ++++
 lib/fstream/internal.h                      |  4 ++
 lib/fstream/uncompress/autodetect.c         |  1 +
 lib/fstream/uncompress/istream_compressor.c |  5 ++
 lib/fstream/uncompress/zstd.c               | 79 +++++++++++++++++++++++++
 8 files changed, 202 insertions(+)
 create mode 100644 lib/fstream/compress/zstd.c
 create mode 100644 lib/fstream/uncompress/zstd.c

(limited to 'lib')

diff --git a/lib/fstream/Makemodule.am b/lib/fstream/Makemodule.am
index a2e414e..8a1254c 100644
--- a/lib/fstream/Makemodule.am
+++ b/lib/fstream/Makemodule.am
@@ -29,4 +29,10 @@ libfstream_a_SOURCES += lib/fstream/uncompress/gzip.c
 libfstream_a_CPPFLAGS += -DWITH_GZIP
 endif
 
+if WITH_ZSTD
+libfstream_a_SOURCES += lib/fstream/compress/zstd.c
+libfstream_a_SOURCES += lib/fstream/uncompress/zstd.c
+libfstream_a_CPPFLAGS += -DWITH_ZSTD
+endif
+
 noinst_LIBRARIES += libfstream.a
diff --git a/lib/fstream/compress/ostream_compressor.c b/lib/fstream/compress/ostream_compressor.c
index 5137f1d..d1d55e1 100644
--- a/lib/fstream/compress/ostream_compressor.c
+++ b/lib/fstream/compress/ostream_compressor.c
@@ -75,6 +75,11 @@ ostream_t *ostream_compressor_create(ostream_t *strm, int comp_id)
 	case FSTREAM_COMPRESSOR_XZ:
 #ifdef WITH_XZ
 		comp = ostream_xz_create(strm->get_filename(strm));
+#endif
+		break;
+	case FSTREAM_COMPRESSOR_ZSTD:
+#ifdef WITH_ZSTD
+		comp = ostream_zstd_create(strm->get_filename(strm));
 #endif
 		break;
 	default:
diff --git a/lib/fstream/compress/zstd.c b/lib/fstream/compress/zstd.c
new file mode 100644
index 0000000..f4f7f86
--- /dev/null
+++ b/lib/fstream/compress/zstd.c
@@ -0,0 +1,92 @@
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+/*
+ * zstd.c
+ *
+ * Copyright (C) 2019 David Oberhollenzer <goliath@infraroot.at>
+ */
+#include "../internal.h"
+
+#include <zstd.h>
+
+typedef struct {
+	ostream_comp_t base;
+
+	ZSTD_CStream *strm;
+} ostream_zstd_t;
+
+static int flush_inbuf(ostream_comp_t *base, bool finish)
+{
+	ostream_zstd_t *zstd = (ostream_zstd_t *)base;
+	ZSTD_EndDirective op;
+	ZSTD_outBuffer out;
+	ZSTD_inBuffer in;
+	size_t ret;
+
+	op = finish ? ZSTD_e_end : ZSTD_e_continue;
+
+	do {
+		memset(&in, 0, sizeof(in));
+		memset(&out, 0, sizeof(out));
+
+		in.src = base->inbuf;
+		in.size = base->inbuf_used;
+
+		out.dst = base->outbuf;
+		out.size = BUFSZ;
+
+		ret = ZSTD_compressStream2(zstd->strm, &out, &in, op);
+
+		if (ZSTD_isError(ret)) {
+			fprintf(stderr, "%s: error in zstd compressor.\n",
+				base->wrapped->get_filename(base->wrapped));
+			return -1;
+		}
+
+		if (base->wrapped->append(base->wrapped, base->outbuf,
+					  out.pos)) {
+			return -1;
+		}
+
+		if (in.pos < in.size) {
+			base->inbuf_used = in.size - in.pos;
+
+			memmove(base->inbuf, base->inbuf + in.pos,
+				base->inbuf_used);
+		} else {
+			base->inbuf_used = 0;
+		}
+	} while (finish && ret != 0);
+
+	return 0;
+}
+
+static void cleanup(ostream_comp_t *base)
+{
+	ostream_zstd_t *zstd = (ostream_zstd_t *)base;
+
+	ZSTD_freeCStream(zstd->strm);
+}
+
+ostream_comp_t *ostream_zstd_create(const char *filename)
+{
+	ostream_zstd_t *zstd = calloc(1, sizeof(*zstd));
+	ostream_comp_t *base = (ostream_comp_t *)zstd;
+
+	if (zstd == NULL) {
+		fprintf(stderr, "%s: creating zstd wrapper: %s.\n",
+			filename, strerror(errno));
+		return NULL;
+	}
+
+	zstd->strm = ZSTD_createCStream();
+	if (zstd->strm == NULL) {
+		fprintf(stderr, "%s: error creating zstd decoder.\n",
+			filename);
+		free(zstd);
+		return NULL;
+	}
+
+	base->flush_inbuf = flush_inbuf;
+	base->cleanup = cleanup;
+	return base;
+}
diff --git a/lib/fstream/compressor.c b/lib/fstream/compressor.c
index b8f9c6b..84a859c 100644
--- a/lib/fstream/compressor.c
+++ b/lib/fstream/compressor.c
@@ -14,6 +14,9 @@ int fstream_compressor_id_from_name(const char *name)
 	if (strcmp(name, "xz") == 0)
 		return FSTREAM_COMPRESSOR_XZ;
 
+	if (strcmp(name, "zstd") == 0)
+		return FSTREAM_COMPRESSOR_ZSTD;
+
 	return -1;
 }
 
@@ -25,6 +28,9 @@ const char *fstream_compressor_name_from_id(int id)
 	if (id == FSTREAM_COMPRESSOR_XZ)
 		return "xz";
 
+	if (id == FSTREAM_COMPRESSOR_ZSTD)
+		return "zstd";
+
 	return NULL;
 }
 
@@ -38,6 +44,10 @@ bool fstream_compressor_exists(int id)
 #ifdef WITH_XZ
 	case FSTREAM_COMPRESSOR_XZ:
 		return true;
+#endif
+#ifdef WITH_ZSTD
+	case FSTREAM_COMPRESSOR_ZSTD:
+		return true;
 #endif
 	default:
 		break;
diff --git a/lib/fstream/internal.h b/lib/fstream/internal.h
index 160a523..83ecc64 100644
--- a/lib/fstream/internal.h
+++ b/lib/fstream/internal.h
@@ -57,10 +57,14 @@ SQFS_INTERNAL ostream_comp_t *ostream_gzip_create(const char *filename);
 
 SQFS_INTERNAL ostream_comp_t *ostream_xz_create(const char *filename);
 
+SQFS_INTERNAL ostream_comp_t *ostream_zstd_create(const char *filename);
+
 SQFS_INTERNAL istream_comp_t *istream_gzip_create(const char *filename);
 
 SQFS_INTERNAL istream_comp_t *istream_xz_create(const char *filename);
 
+SQFS_INTERNAL istream_comp_t *istream_zstd_create(const char *filename);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/lib/fstream/uncompress/autodetect.c b/lib/fstream/uncompress/autodetect.c
index 4ffe078..b788518 100644
--- a/lib/fstream/uncompress/autodetect.c
+++ b/lib/fstream/uncompress/autodetect.c
@@ -13,6 +13,7 @@ static const struct {
 } magic[] = {
 	{ FSTREAM_COMPRESSOR_GZIP, (const sqfs_u8 *)"\x1F\x8B\x08", 3 },
 	{ FSTREAM_COMPRESSOR_XZ, (const sqfs_u8 *)("\xFD" "7zXZ"), 6 },
+	{ FSTREAM_COMPRESSOR_ZSTD, (const sqfs_u8 *)"\x28\xB5\x2F\xFD", 4 },
 };
 
 int istream_detect_compressor(istream_t *strm,
diff --git a/lib/fstream/uncompress/istream_compressor.c b/lib/fstream/uncompress/istream_compressor.c
index 924f309..2262c9b 100644
--- a/lib/fstream/uncompress/istream_compressor.c
+++ b/lib/fstream/uncompress/istream_compressor.c
@@ -37,6 +37,11 @@ istream_t *istream_compressor_create(istream_t *strm, int comp_id)
 	case FSTREAM_COMPRESSOR_XZ:
 #ifdef WITH_XZ
 		comp = istream_xz_create(strm->get_filename(strm));
+#endif
+		break;
+	case FSTREAM_COMPRESSOR_ZSTD:
+#ifdef WITH_ZSTD
+		comp = istream_zstd_create(strm->get_filename(strm));
 #endif
 		break;
 	default:
diff --git a/lib/fstream/uncompress/zstd.c b/lib/fstream/uncompress/zstd.c
new file mode 100644
index 0000000..1838af5
--- /dev/null
+++ b/lib/fstream/uncompress/zstd.c
@@ -0,0 +1,79 @@
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+/*
+ * zstd.c
+ *
+ * Copyright (C) 2019 David Oberhollenzer <goliath@infraroot.at>
+ */
+#include "../internal.h"
+
+#include <zstd.h>
+
+typedef struct {
+	istream_comp_t base;
+
+	ZSTD_DStream* strm;
+} istream_zstd_t;
+
+static int precache(istream_t *base)
+{
+	istream_zstd_t *zstd = (istream_zstd_t *)base;
+	istream_t *wrapped = ((istream_comp_t *)base)->wrapped;
+	ZSTD_outBuffer out;
+	ZSTD_inBuffer in;
+	size_t ret;
+
+	if (istream_precache(wrapped))
+		return -1;
+
+	memset(&in, 0, sizeof(in));
+	memset(&out, 0, sizeof(out));
+
+	in.src = wrapped->buffer;
+	in.size = wrapped->buffer_used;
+
+	out.dst = ((istream_comp_t *)base)->uncompressed + base->buffer_used;
+	out.size = BUFSZ - base->buffer_used;
+
+	ret = ZSTD_decompressStream(zstd->strm, &out, &in);
+
+	if (ZSTD_isError(ret)) {
+		fprintf(stderr, "%s: error in zstd decoder.\n",
+			wrapped->get_filename(wrapped));
+		return -1;
+	}
+
+	wrapped->buffer_offset = in.pos;
+	base->buffer_used += out.pos;
+	return 0;
+}
+
+static void cleanup(istream_comp_t *base)
+{
+	istream_zstd_t *zstd = (istream_zstd_t *)base;
+
+	ZSTD_freeDStream(zstd->strm);
+}
+
+istream_comp_t *istream_zstd_create(const char *filename)
+{
+	istream_zstd_t *zstd = calloc(1, sizeof(*zstd));
+	istream_comp_t *base = (istream_comp_t *)zstd;
+
+	if (zstd == NULL) {
+		fprintf(stderr, "%s: creating zstd decoder: %s.\n",
+			filename, strerror(errno));
+		return NULL;
+	}
+
+	zstd->strm = ZSTD_createDStream();
+	if (zstd->strm == NULL) {
+		fprintf(stderr, "%s: error creating zstd decoder.\n",
+			filename);
+		free(zstd);
+		return NULL;
+	}
+
+	((istream_t *)base)->precache = precache;
+	base->cleanup = cleanup;
+	return base;
+}
-- 
cgit v1.2.3