[PATCH v2 2/2] webservices: Enforce the session dictionary size limit on the receive dictionary.

Connor McAdams cmcadams at codeweavers.com
Wed Apr 20 17:58:40 CDT 2022

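Apply the channel's session dictionary size (2048 bytes by default) to
the receive dictionary as well as to the send dictionary. If an incoming
dictionary string would exceed the limit, the receive fails with
WS_E_QUOTA_EXCEEDED and the channel enters the faulted state. Send,
receive and shutdown calls on a faulted channel then return
WS_E_OBJECT_FAULTED instead of WS_E_INVALID_OPERATION.
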
Signed-off-by: Connor McAdams <cmcadams at codeweavers.com>
---
 dlls/webservices/channel.c       | 25 +++++++++++++++++++++++--
 dlls/webservices/tests/channel.c | 17 +++++++++++++++--
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/dlls/webservices/channel.c b/dlls/webservices/channel.c
index 325e842635b..1be1b6572f3 100644
--- a/dlls/webservices/channel.c
+++ b/dlls/webservices/channel.c
@@ -328,7 +328,7 @@ static void reset_channel( struct channel *channel )
     channel->session_state = SESSION_STATE_UNINITIALIZED;
     clear_addr( &channel->addr );
     init_dict( &channel->dict_send, channel->dict_size );
-    init_dict( &channel->dict_recv, 0 );
+    init_dict( &channel->dict_recv, channel->dict_size );
     channel->msg           = NULL;
     channel->read_size     = 0;
     channel->send_size     = 0;
@@ -486,6 +486,7 @@ static HRESULT create_channel( WS_CHANNEL_TYPE type, WS_CHANNEL_BINDING binding,
         channel->encoding     = WS_ENCODING_XML_BINARY_SESSION_1;
         channel->dict_size    = 2048;
         channel->dict_send.str_bytes_max = channel->dict_size;
+        channel->dict_recv.str_bytes_max = channel->dict_size;
         break;
 
     case WS_UDP_CHANNEL_BINDING:
@@ -546,6 +547,7 @@ static HRESULT create_channel( WS_CHANNEL_TYPE type, WS_CHANNEL_BINDING binding,
 
             channel->dict_size = *(ULONG *)prop->value;
             channel->dict_send.str_bytes_max = channel->dict_size;
+            channel->dict_recv.str_bytes_max = channel->dict_size;
             break;
 
         default:
@@ -897,6 +899,7 @@ HRESULT WINAPI WsShutdownSessionChannel( WS_CHANNEL *handle, const WS_ASYNC_CONT
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -1595,6 +1598,7 @@ HRESULT channel_send_message( WS_CHANNEL *handle, WS_MESSAGE *msg )
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -1784,6 +1788,7 @@ HRESULT WINAPI WsSendMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_MESS
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -1830,6 +1835,7 @@ HRESULT WINAPI WsSendReplyMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -2196,6 +2202,11 @@ static HRESULT build_dict( const BYTE *buf, ULONG buflen, struct dictionary *dic
             init_dict( dict, 0 );
             return WS_E_INVALID_FORMAT;
         }
+        if ((size + dict->str_bytes + 1) > dict->str_bytes_max)
+        {
+            hr = WS_E_QUOTA_EXCEEDED;
+            goto error;
+        }
         buflen -= size;
         if (!(bytes = malloc( size )))
         {
@@ -2241,7 +2252,11 @@ static HRESULT receive_message_bytes_session( struct channel *channel )
     {
         ULONG size;
         if ((hr = build_dict( (const BYTE *)channel->read_buf, channel->read_size, &channel->dict_recv,
-                              &size )) != S_OK) return hr;
+                              &size )) != S_OK)
+        {
+            if (hr == WS_E_QUOTA_EXCEEDED) channel->state = WS_CHANNEL_STATE_FAULTED;
+            return hr;
+        }
         channel->read_size -= size;
         memmove( channel->read_buf, channel->read_buf + size, channel->read_size );
     }
@@ -2301,6 +2316,7 @@ HRESULT channel_receive_message( WS_CHANNEL *handle, WS_MESSAGE *msg )
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -2441,6 +2457,7 @@ HRESULT WINAPI WsReceiveMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_M
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -2559,6 +2576,7 @@ HRESULT WINAPI WsRequestReply( WS_CHANNEL *handle, WS_MESSAGE *request, const WS
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -2643,6 +2661,7 @@ HRESULT WINAPI WsReadMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -2795,6 +2814,7 @@ HRESULT WINAPI WsWriteMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const W
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
@@ -2875,6 +2895,7 @@ HRESULT WINAPI WsWriteMessageEnd( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_
     if (channel->state != WS_CHANNEL_STATE_OPEN)
     {
         LeaveCriticalSection( &channel->cs );
+        if (channel->state == WS_CHANNEL_STATE_FAULTED) return WS_E_OBJECT_FAULTED;
         return WS_E_INVALID_OPERATION;
     }
 
diff --git a/dlls/webservices/tests/channel.c b/dlls/webservices/tests/channel.c
index d5d35904c1c..7e0a9becbc5 100644
--- a/dlls/webservices/tests/channel.c
+++ b/dlls/webservices/tests/channel.c
@@ -822,6 +822,7 @@ static void client_duplex_session_dict( const struct listener_info *info )
     WS_MESSAGE_DESCRIPTION desc;
     WS_ENDPOINT_ADDRESS addr;
     WS_CHANNEL_PROPERTY prop;
+    WS_CHANNEL_STATE state;
     int dict_str_cnt = 0;
     char elem_name[128];
     WS_CHANNEL *channel;
@@ -893,12 +894,24 @@ static void client_duplex_session_dict( const struct listener_info *info )
     local_name.bytes = (BYTE *)short_dict_str;
     hr = WsReceiveMessage( channel, msg, descs, 1, WS_RECEIVE_REQUIRED_MESSAGE, WS_READ_REQUIRED_VALUE,
                            NULL, &val, sizeof(val), NULL, NULL, NULL );
-    todo_wine ok( hr == WS_E_QUOTA_EXCEEDED, "got %#lx\n", hr);
+    ok( hr == WS_E_QUOTA_EXCEEDED, "got %#lx\n", hr);
+
+    state = 0xdeadbeef;
+    hr = WsGetChannelProperty( channel, WS_CHANNEL_PROPERTY_STATE, &state, sizeof(state), NULL );
+    ok( hr == S_OK, "got %#lx\n", hr );
+    ok( state == WS_CHANNEL_STATE_FAULTED, "got %u\n", state );
+
+    hr = WsReceiveMessage( channel, msg, descs, 1, WS_RECEIVE_REQUIRED_MESSAGE, WS_READ_REQUIRED_VALUE,
+                           NULL, &val, sizeof(val), NULL, NULL, NULL );
+    ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr );
+
+    hr = WsSendMessage( channel, msg, &desc, WS_WRITE_REQUIRED_VALUE, &val, sizeof(val), NULL, NULL );
+    ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr );
 
     WsFreeMessage( msg );
 
     hr = WsShutdownSessionChannel( channel, NULL, NULL );
-    todo_wine ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr );
+    ok( hr == WS_E_OBJECT_FAULTED, "got %#lx\n", hr );
 
     hr = WsCloseChannel( channel, NULL, NULL );
     ok( hr == S_OK, "got %#lx\n", hr );
-- 
2.25.1




More information about the wine-devel mailing list