Mailing List Archive

[PATCH 4/5] Add support for ZSTD compression
From: Sebastian Andrzej Siewior <sebastian@breakpoint.cc>

The "zstd@openssh.com" compression method enables Zstandard (ZSTD)
compression as defined in RFC 8478. As with the "zlib@openssh.com" method,
compression is delayed until the server sends SSH_MSG_USERAUTH_SUCCESS.
Support is only built when configure is run with --with-libzstd.

Signed-off-by: Sebastian Andrzej Siewior <sebastian@breakpoint.cc>
---
cipher.c | 30 +++++-
configure.ac | 8 ++
kex.c | 5 +
kex.h | 3 +
myproposal.h | 2 +-
packet.c | 273 +++++++++++++++++++++++++++++++++++++++++++++------
readconf.c | 8 +-
servconf.c | 14 +--
ssh.c | 4 +-
9 files changed, 301 insertions(+), 46 deletions(-)

diff --git a/cipher.c b/cipher.c
index 8195199b32a2d..e1d0b5637a603 100644
--- a/cipher.c
+++ b/cipher.c
@@ -48,6 +48,7 @@
#include "sshbuf.h"
#include "ssherr.h"
#include "digest.h"
+#include "kex.h"

#include "openbsd-compat/openssl-compat.h"

@@ -146,12 +147,33 @@ cipher_alg_list(char sep, int auth_only)
const char *
compression_alg_list(int compression)
{
-#ifdef WITH_ZLIB
- return compression ? "zlib@openssh.com,zlib,none" :
- "none,zlib@openssh.com,zlib";
+#ifdef HAVE_LIBZSTD /* *_WITH: entries placed before "none"; *_NONE: entries appended after "none" */
+#define COMP_ZSTD_WITH "zstd@openssh.com,"
+#define COMP_ZSTD_NONE ",zstd@openssh.com"
#else
- return "none";
+#define COMP_ZSTD_WITH ""
+#define COMP_ZSTD_NONE ""
#endif
+
+#ifdef WITH_ZLIB /* _C_: client names (includes pre-auth "zlib"); _S_: server names (delayed only) */
+#define COMP_ZLIB_C_WITH "zlib@openssh.com,zlib,"
+#define COMP_ZLIB_S_WITH "zlib@openssh.com,"
+
+#define COMP_ZLIB_C_NONE ",zlib@openssh.com,zlib"
+#else
+#define COMP_ZLIB_C_WITH ""
+#define COMP_ZLIB_S_WITH ""
+#define COMP_ZLIB_C_NONE ""
+#endif
+ switch (compression) { /* a COMP_* configuration value from kex.h */
+ case COMP_ZLIB: return COMP_ZLIB_C_WITH "none";
+ case COMP_DELAYED: return COMP_ZLIB_S_WITH "none";
+ case COMP_ZSTD: return COMP_ZSTD_WITH "none";
+ case COMP_ALL_C: return COMP_ZSTD_WITH COMP_ZLIB_C_WITH "none";
+ case COMP_ALL_S: return COMP_ZSTD_WITH COMP_ZLIB_S_WITH "none";
+ default: /* unknown values get the same "none"-first list as 0 */
+ case 0: return "none" COMP_ZSTD_NONE COMP_ZLIB_C_NONE;
+ }
}

u_int
diff --git a/configure.ac b/configure.ac
index 28947a6608455..88645473f2d51 100644
--- a/configure.ac
+++ b/configure.ac
@@ -1403,6 +1403,14 @@ See http://www.gzip.org/zlib/ for details.])
)
fi

+AC_ARG_WITH([libzstd], AS_HELP_STRING([--with-libzstd], [Build with libzstd.])) dnl zstd support is only built for --with-libzstd(=yes)
+AS_IF([test "x$with_libzstd" = "xyes"],
+ [
+ PKG_CHECK_MODULES([LIBZSTD], [libzstd >= 1.4.0], [AC_DEFINE([HAVE_LIBZSTD], [1], [Use LIBZSTD])])
+ LIBS="$LIBS ${LIBZSTD_LIBS}"
+ CFLAGS="$CFLAGS ${LIBZSTD_CFLAGS}"
+ ])
+
dnl UnixWare 2.x
AC_CHECK_FUNC([strcasecmp],
[], [ AC_CHECK_LIB([resolv], [strcasecmp], [LIBS="$LIBS -lresolv"]) ]
diff --git a/kex.c b/kex.c
index aecb9394d8053..c2b51eb0982ab 100644
--- a/kex.c
+++ b/kex.c
@@ -804,6 +804,11 @@ choose_comp(struct sshcomp *comp, char *client, char *server)
comp->type = COMP_ZLIB;
} else
#endif /* WITH_ZLIB */
+#ifdef HAVE_LIBZSTD /* without libzstd this name is not recognized here */
+ if (strcmp(name, "zstd@openssh.com") == 0) {
+ comp->type = COMP_ZSTD; /* started only after user authentication */
+ } else
+#endif /* HAVE_LIBZSTD */
if (strcmp(name, "none") == 0) {
comp->type = COMP_NONE;
} else {
diff --git a/kex.h b/kex.h
index a5ae6ac050a78..5efe146d796c6 100644
--- a/kex.h
+++ b/kex.h
@@ -68,6 +68,9 @@
/* pre-auth compression (COMP_ZLIB) is only supported in the client */
#define COMP_ZLIB 1
#define COMP_DELAYED 2
+#define COMP_ZSTD 3 /* zstd@openssh.com, enabled only after authentication */
+#define COMP_ALL_C 4 /* config value only: every method the client was built with */
+#define COMP_ALL_S 5 /* config value only: every method the server was built with */

#define CURVE25519_SIZE 32

diff --git a/myproposal.h b/myproposal.h
index 5312e60581ced..4840bef213584 100644
--- a/myproposal.h
+++ b/myproposal.h
@@ -89,7 +89,12 @@
"rsa-sha2-512," \
"rsa-sha2-256"

-#define KEX_DEFAULT_COMP "none,zlib@openssh.com"
+/* Only offer zstd when choose_comp() is built to accept it. */
+#ifdef HAVE_LIBZSTD
+#define KEX_DEFAULT_COMP "none,zstd@openssh.com,zlib@openssh.com"
+#else
+#define KEX_DEFAULT_COMP "none,zlib@openssh.com"
+#endif
#define KEX_DEFAULT_LANG ""

#define KEX_CLIENT \
diff --git a/packet.c b/packet.c
index 00e3180cb0ab7..d49c9594cf60d 100644
--- a/packet.c
+++ b/packet.c
@@ -79,6 +79,9 @@
#ifdef WITH_ZLIB
#include <zlib.h>
#endif
+#ifdef HAVE_LIBZSTD
+#include <zstd.h>
+#endif

#include "xmalloc.h"
#include "compat.h"
@@ -156,6 +159,14 @@ struct session_state {
/* Incoming/outgoing compression dictionaries */
z_stream compression_in_stream;
z_stream compression_out_stream;
+#endif
+#ifdef HAVE_LIBZSTD
+ ZSTD_DCtx *compression_zstd_in_stream; /* created lazily by start_compression_zstd_in() */
+ ZSTD_CCtx *compression_zstd_out_stream; /* created lazily by start_compression_zstd_out() */
+ u_int64_t compress_zstd_in_raw; /* decompressed bytes produced */
+ u_int64_t compress_zstd_in_comp; /* compressed bytes fed to the decoder */
+ u_int64_t compress_zstd_out_raw; /* uncompressed bytes fed to the encoder */
+ u_int64_t compress_zstd_out_comp; /* compressed bytes produced */
#endif
int compression_in_started;
int compression_out_started;
@@ -616,11 +627,11 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close)
state->newkeys[mode] = NULL;
ssh_clear_newkeys(ssh, mode); /* next keys */
}
-#ifdef WITH_ZLIB
/* compression state is in shared mem, so we can only release it once */
if (do_close && state->compression_buffer) {
sshbuf_free(state->compression_buffer);
- if (state->compression_out_started) {
+#ifdef WITH_ZLIB
+ if (state->compression_out_started == COMP_ZLIB) { /* zlib stream: may need deflateEnd() */
z_streamp stream = &state->compression_out_stream;
debug("compress outgoing: "
"raw data %llu, compressed %llu, factor %.2f",
@@ -631,7 +642,7 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close)
if (state->compression_out_failures == 0)
deflateEnd(stream);
}
- if (state->compression_in_started) {
+ if (state->compression_in_started == COMP_ZLIB) { /* zlib stream: may need inflateEnd() */
z_streamp stream = &state->compression_in_stream;
debug("compress incoming: "
"raw data %llu, compressed %llu, factor %.2f",
@@ -642,8 +653,36 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close)
if (state->compression_in_failures == 0)
inflateEnd(stream);
}
+#endif /* WITH_ZLIB */
+#ifdef HAVE_LIBZSTD
+	if (state->compression_out_started == COMP_ZSTD) {
+		debug("compress outgoing: "
+		    "raw data %llu, compressed %llu, factor %.2f",
+		    (unsigned long long)state->compress_zstd_out_raw,
+		    (unsigned long long)state->compress_zstd_out_comp,
+		    state->compress_zstd_out_raw == 0 ? 0.0 :
+		    (double) state->compress_zstd_out_comp /
+		    state->compress_zstd_out_raw);
+	}
+	if (state->compression_in_started == COMP_ZSTD) {
+		debug("compress incoming: "
+		    "raw data %llu, compressed %llu, factor %.2f",
+		    (unsigned long long)state->compress_zstd_in_raw,
+		    (unsigned long long)state->compress_zstd_in_comp,
+		    state->compress_zstd_in_raw == 0 ? 0.0 :
+		    (double) state->compress_zstd_in_comp /
+		    state->compress_zstd_in_raw);
+	}
+	/*
+	 * Free both contexts unconditionally: one may still be allocated even
+	 * though *_started no longer says COMP_ZSTD. NULL is accepted.
+	 */
+	ZSTD_freeCCtx(state->compression_zstd_out_stream);
+	ZSTD_freeDCtx(state->compression_zstd_in_stream);
+	state->compression_zstd_out_stream = NULL;
+	state->compression_zstd_in_stream = NULL;
+#endif /* HAVE_LIBZSTD */
}
-#endif /* WITH_ZLIB */
cipher_free(state->send_context);
cipher_free(state->receive_context);
state->send_context = state->receive_context = NULL;
@@ -708,11 +739,11 @@ start_compression_out(struct ssh *ssh, int level)
if (level < 1 || level > 9)
return SSH_ERR_INVALID_ARGUMENT;
debug("Enabling compression at level %d.", level);
- if (ssh->state->compression_out_started == 1)
+ if (ssh->state->compression_out_started == COMP_ZLIB) /* restarting: end the previous stream first */
deflateEnd(&ssh->state->compression_out_stream);
switch (deflateInit(&ssh->state->compression_out_stream, level)) {
case Z_OK:
- ssh->state->compression_out_started = 1;
+ ssh->state->compression_out_started = COMP_ZLIB; /* records that this direction now uses zlib */
break;
case Z_MEM_ERROR:
return SSH_ERR_ALLOC_FAIL;
@@ -725,11 +756,11 @@ start_compression_out(struct ssh *ssh, int level)
static int
start_compression_in(struct ssh *ssh)
{
- if (ssh->state->compression_in_started == 1)
+ if (ssh->state->compression_in_started == COMP_ZLIB) /* restarting: end the previous stream first */
inflateEnd(&ssh->state->compression_in_stream);
switch (inflateInit(&ssh->state->compression_in_stream)) {
case Z_OK:
- ssh->state->compression_in_started = 1;
+ ssh->state->compression_in_started = COMP_ZLIB; /* records that this direction now uses zlib */
break;
case Z_MEM_ERROR:
return SSH_ERR_ALLOC_FAIL;
@@ -746,7 +777,7 @@ compress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
u_char buf[4096];
int r, status;

- if (ssh->state->compression_out_started != 1)
+ if (ssh->state->compression_out_started != COMP_ZLIB) /* zstd output goes through compress_buffer_zstd() */
return SSH_ERR_INTERNAL_ERROR;

/* This case is not handled below. */
@@ -792,7 +823,7 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
u_char buf[4096];
int r, status;

- if (ssh->state->compression_in_started != 1)
+ if (ssh->state->compression_in_started != COMP_ZLIB) /* zstd input goes through uncompress_buffer_zstd() */
return SSH_ERR_INTERNAL_ERROR;

if ((ssh->state->compression_in_stream.next_in =
@@ -860,6 +891,163 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
}
#endif /* WITH_ZLIB */

+#ifdef HAVE_LIBZSTD
+/*
+ * Create the outgoing ZSTD compression context, or reset the existing one
+ * (session only, keeping its parameters and memory) when compression is
+ * started again for this direction.
+ */
+static int
+start_compression_zstd_out(struct ssh *ssh)
+{
+	struct session_state *state = ssh->state;
+
+	debug("Enabling ZSTD compression.");
+	if (state->compression_out_started == COMP_ZSTD &&
+	    ZSTD_isError(ZSTD_CCtx_reset(state->compression_zstd_out_stream,
+	    ZSTD_reset_session_only)))
+		return SSH_ERR_INTERNAL_ERROR;
+	if (state->compression_zstd_out_stream == NULL &&
+	    (state->compression_zstd_out_stream = ZSTD_createCCtx()) == NULL)
+		return SSH_ERR_ALLOC_FAIL;
+	state->compression_out_started = COMP_ZSTD;
+	return 0;
+}
+
+/* Incoming counterpart of start_compression_zstd_out(). */
+static int
+start_compression_zstd_in(struct ssh *ssh)
+{
+	struct session_state *state = ssh->state;
+
+	if (state->compression_in_started == COMP_ZSTD &&
+	    ZSTD_isError(ZSTD_DCtx_reset(state->compression_zstd_in_stream,
+	    ZSTD_reset_session_only)))
+		return SSH_ERR_INTERNAL_ERROR;
+	if (state->compression_zstd_in_stream == NULL &&
+	    (state->compression_zstd_in_stream = ZSTD_createDCtx()) == NULL)
+		return SSH_ERR_ALLOC_FAIL;
+	state->compression_in_started = COMP_ZSTD;
+	return 0;
+}
+
+/*
+ * Compress all of "in" and append the result to "out". Every call uses
+ * ZSTD_e_flush so all input handed in is emitted before returning; the
+ * compression context itself is kept in ssh->state between packets.
+ */
+static int
+compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	struct session_state *state = ssh->state;
+	u_char buf[4096];
+	ZSTD_inBuffer in_buff;
+	ZSTD_outBuffer out_buff;
+	size_t pending;	/* size or error code, as returned by ZSTD */
+	int r;
+
+	if (state->compression_out_started != COMP_ZSTD)
+		return SSH_ERR_INTERNAL_ERROR;
+	if (sshbuf_len(in) == 0)
+		return 0;
+	/* Input is only read, so the const accessor is sufficient. */
+	if ((in_buff.src = sshbuf_ptr(in)) == NULL)
+		return SSH_ERR_INTERNAL_ERROR;
+	in_buff.size = sshbuf_len(in);
+	in_buff.pos = 0;
+	state->compress_zstd_out_raw += in_buff.size;
+
+	/*
+	 * Consume the input and immediately flush the compressed data; loop
+	 * while ZSTD reports output still pending because buf was too small.
+	 */
+	do {
+		out_buff.dst = buf;
+		out_buff.size = sizeof(buf);
+		out_buff.pos = 0;
+		pending = ZSTD_compressStream2(
+		    state->compression_zstd_out_stream, &out_buff, &in_buff,
+		    ZSTD_e_flush);
+		if (ZSTD_isError(pending))
+			return SSH_ERR_INTERNAL_ERROR;
+		if ((r = sshbuf_put(out, buf, out_buff.pos)) != 0)
+			return r;
+		state->compress_zstd_out_comp += out_buff.pos;
+	} while (pending > 0);
+	return 0;
+}
+
+/*
+ * Decompress all of "in" and append the result to "out". Output is drained
+ * through a fixed-size buffer; a completely filled buffer may mean more
+ * output is pending, so the loop ends only once all input is consumed and
+ * the last round left room in buf.
+ */
+static int
+uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	struct session_state *state = ssh->state;
+	u_char buf[4096];
+	ZSTD_inBuffer in_buff;
+	ZSTD_outBuffer out_buff;
+	size_t ret;	/* size hint or error code, as returned by ZSTD */
+	int r;
+
+	if (state->compression_in_started != COMP_ZSTD)
+		return SSH_ERR_INTERNAL_ERROR;
+	if ((in_buff.src = sshbuf_ptr(in)) == NULL)
+		return SSH_ERR_INTERNAL_ERROR;
+	in_buff.size = sshbuf_len(in);
+	in_buff.pos = 0;
+	state->compress_zstd_in_comp += in_buff.size;
+
+	for (;;) {
+		out_buff.dst = buf;
+		out_buff.size = sizeof(buf);
+		out_buff.pos = 0;
+		ret = ZSTD_decompressStream(state->compression_zstd_in_stream,
+		    &out_buff, &in_buff);
+		/* Every decoder error is reported as malformed input. */
+		if (ZSTD_isError(ret))
+			return SSH_ERR_INVALID_FORMAT;
+		if ((r = sshbuf_put(out, buf, out_buff.pos)) != 0)
+			return r;
+		state->compress_zstd_in_raw += out_buff.pos;
+		if (in_buff.pos == in_buff.size && out_buff.pos < sizeof(buf))
+			return 0;
+	}
+}
+#else /* HAVE_LIBZSTD */
+
+/*
+ * Stubs for builds without libzstd: choose_comp() only selects COMP_ZSTD
+ * when HAVE_LIBZSTD is defined, so reaching any of these is a bug.
+ */
+static int
+start_compression_zstd_out(struct ssh *ssh)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+
+static int
+start_compression_zstd_in(struct ssh *ssh)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+
+static int
+compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+
+static int
+uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+#endif /* HAVE_LIBZSTD */
+
void
ssh_clear_newkeys(struct ssh *ssh, int mode)
{
@@ -936,18 +1104,30 @@ ssh_set_newkeys(struct ssh *ssh, int mode)
explicit_bzero(enc->key, enc->key_len);
explicit_bzero(mac->key, mac->key_len); */
if ((comp->type == COMP_ZLIB ||
- (comp->type == COMP_DELAYED &&
+ ((comp->type == COMP_DELAYED ||
+ comp->type == COMP_ZSTD) && /* zstd is post-auth only, like delayed zlib */
state->after_authentication)) && comp->enabled == 0) {
if ((r = ssh_packet_init_compression(ssh)) < 0)
return r;
- if (mode == MODE_OUT) {
- if ((r = start_compression_out(ssh, 6)) != 0)
- return r;
+ if (comp->type == COMP_ZSTD) {
+ if (mode == MODE_OUT) {
+ if ((r = start_compression_zstd_out(ssh)) != 0)
+ return r;
+ } else {
+ if ((r = start_compression_zstd_in(ssh)) != 0)
+ return r;
+ }
+ comp->enabled = COMP_ZSTD; /* selects the zstd send/read paths */
} else {
- if ((r = start_compression_in(ssh)) != 0)
- return r;
+ if (mode == MODE_OUT) {
+ if ((r = start_compression_out(ssh, 6)) != 0)
+ return r;
+ } else {
+ if ((r = start_compression_in(ssh)) != 0)
+ return r;
+ }
+ comp->enabled = COMP_ZLIB; /* for both COMP_ZLIB and COMP_DELAYED */
}
- comp->enabled = 1;
}
/*
* The 2^(blocksize*2) limit is too expensive for 3DES,
@@ -1025,6 +1205,7 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh)
struct session_state *state = ssh->state;
struct sshcomp *comp = NULL;
int r, mode;
+ int type = 0;

/*
* Remember that we are past the authentication step, so rekeying
@@ -1036,17 +1217,25 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh)
if (state->newkeys[mode] == NULL)
continue;
comp = &state->newkeys[mode]->comp;
- if (comp && !comp->enabled && comp->type == COMP_DELAYED) {
- if ((r = ssh_packet_init_compression(ssh)) != 0)
+		/* Recomputed for every direction; never carried over. */
+		type = comp->enabled ? COMP_NONE : comp->type;
+		if (type == COMP_DELAYED || type == COMP_ZSTD) {
+			if ((r = ssh_packet_init_compression(ssh)) != 0) {
return r;
- if (mode == MODE_OUT) {
- if ((r = start_compression_out(ssh, 6)) != 0)
- return r;
- } else {
- if ((r = start_compression_in(ssh)) != 0)
- return r;
}
- comp->enabled = 1;
+			if (type == COMP_ZSTD) {
+				r = (mode == MODE_OUT) ?
+				    start_compression_zstd_out(ssh) :
+				    start_compression_zstd_in(ssh);
+			} else {
+				r = (mode == MODE_OUT) ?
+				    start_compression_out(ssh, 6) :
+				    start_compression_in(ssh);
+			}
+			if (r != 0)
+				return r;
+			/* Tells the send/read paths which codec to use. */
+			comp->enabled = (type == COMP_ZSTD) ? COMP_ZSTD : COMP_ZLIB;
}
}
return 0;
@@ -1107,9 +1304,15 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
if ((r = sshbuf_consume(state->outgoing_packet, 5)) != 0)
goto out;
sshbuf_reset(state->compression_buffer);
- if ((r = compress_buffer(ssh, state->outgoing_packet,
- state->compression_buffer)) != 0)
- goto out;
+ if (comp->enabled == COMP_ZSTD) { /* comp->enabled holds COMP_ZSTD or COMP_ZLIB */
+ if ((r = compress_buffer_zstd(ssh, state->outgoing_packet,
+ state->compression_buffer)) != 0)
+ goto out;
+ } else { /* COMP_ZLIB */
+ if ((r = compress_buffer(ssh, state->outgoing_packet,
+ state->compression_buffer)) != 0)
+ goto out;
+ }
sshbuf_reset(state->outgoing_packet);
if ((r = sshbuf_put(state->outgoing_packet,
"\0\0\0\0\0", 5)) != 0 ||
@@ -1668,9 +1871,15 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
sshbuf_len(state->incoming_packet)));
if (comp && comp->enabled) {
sshbuf_reset(state->compression_buffer);
- if ((r = uncompress_buffer(ssh, state->incoming_packet,
- state->compression_buffer)) != 0)
- goto out;
+ if (comp->enabled == COMP_ZSTD) { /* comp->enabled holds COMP_ZSTD or COMP_ZLIB */
+ if ((r = uncompress_buffer_zstd(ssh, state->incoming_packet,
+ state->compression_buffer)) != 0)
+ goto out;
+ } else { /* COMP_ZLIB */
+ if ((r = uncompress_buffer(ssh, state->incoming_packet,
+ state->compression_buffer)) != 0)
+ goto out;
+ }
sshbuf_reset(state->incoming_packet);
if ((r = sshbuf_putb(state->incoming_packet,
state->compression_buffer)) != 0)
diff --git a/readconf.c b/readconf.c
index 554efd7c9c027..46c2cde18f330 100644
--- a/readconf.c
+++ b/readconf.c
@@ -870,8 +870,14 @@ static const struct multistate multistate_canonicalizehostname[] = {
{ NULL, -1 }
};
static const struct multistate multistate_compression[] = {
+#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD)
+ { "yes", COMP_ALL_C }, /* every compiled-in method */
+#endif
#ifdef WITH_ZLIB
- { "yes", COMP_ZLIB },
+ { "zlib", COMP_ZLIB }, /* zlib only, including pre-auth "zlib" */
+#endif
+#ifdef HAVE_LIBZSTD
+ { "zstd", COMP_ZSTD }, /* zstd only */
#endif
{ "no", COMP_NONE },
{ NULL, -1 }
diff --git a/servconf.c b/servconf.c
index f08e37477957a..104240317bf8d 100644
--- a/servconf.c
+++ b/servconf.c
@@ -393,11 +393,11 @@ fill_default_server_options(ServerOptions *options)
options->permit_user_env_allowlist = NULL;
}
if (options->compression == -1)
-#ifdef WITH_ZLIB
- options->compression = COMP_DELAYED;
+#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD)
+		options->compression = COMP_ALL_S; /* every compiled-in method */
#else
options->compression = COMP_NONE;
#endif

if (options->rekey_limit == -1)
options->rekey_limit = 0;
@@ -1234,9 +1230,15 @@ static const struct multistate multistate_permitrootlogin[] = {
{ NULL, -1 }
};
static const struct multistate multistate_compression[] = {
+#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD)
+ { "yes", COMP_ALL_S }, /* every compiled-in method */
+#endif
#ifdef WITH_ZLIB
- { "yes", COMP_DELAYED },
{ "delayed", COMP_DELAYED },
+ { "zlib", COMP_DELAYED }, /* same as "delayed" */
+#endif
+#ifdef HAVE_LIBZSTD
+ { "zstd", COMP_ZSTD },
#endif
{ "no", COMP_NONE },
{ NULL, -1 }
diff --git a/ssh.c b/ssh.c
index 9c6a6278bb4a4..e1272d4fb49bd 100644
--- a/ssh.c
+++ b/ssh.c
@@ -1029,8 +1029,8 @@ main(int ac, char **av)
break;

case 'C':
-#ifdef WITH_ZLIB
- options.compression = 1;
+#if defined(HAVE_LIBZSTD) || defined(WITH_ZLIB)
+ options.compression = COMP_ALL_C; /* offer every compiled-in method */
#else
error("Compression not supported, disabling.");
#endif
--
2.28.0

_______________________________________________
openssh-unix-dev mailing list
openssh-unix-dev@mindrot.org
https://lists.mindrot.org/mailman/listinfo/openssh-unix-dev