00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00032 #include <s3fc/s3_socket_tcp_ssl.h>
00033 #include <s3fc/s3_thread_base.h>
00034 #include <s3fc/s3_macros.h>
00035 #include <sstream>
00036 #ifndef _WIN32
00037 #include <sys/time.h>
00038 #endif
00039
00040 bool s3_socket_tcp_ssl::ssl_init_done = false;
00041
00042 static const unsigned long ERROR_CERT_IN_HASH = ERR_PACK(
00043 ERR_LIB_X509,X509_F_X509_STORE_ADD_CERT,X509_R_CERT_ALREADY_IN_HASH_TABLE);
00044
00045 s3_socket_tcp_ssl::s3_socket_tcp_ssl(int newsock, bool new_set_reuseaddr) :
00046 s3_socket_tcp(newsock, new_set_reuseaddr), ssl_error(ERR_LIB_NONE),
00047 ctx(0), ssl(0), keyfile(""), password(""), CAfile(""), CApath("")
00048 {
00049 if( !ssl_init_done )
00050 {
00051 SSL_load_error_strings();
00052 SSL_library_init();
00053
00054 #ifdef _WIN32
00055 RAND_screen();
00056 #endif
00057 ssl_init_done = true;
00058 }
00059 }
00060
00061 s3_socket_tcp_ssl::s3_socket_tcp_ssl(const s3_socket_tcp_ssl& template_sock) :
00062 s3_socket_tcp(template_sock), ssl_error(ERR_LIB_NONE), ctx(0), ssl(0),
00063 keyfile(template_sock.keyfile), password(template_sock.password),
00064 CAfile(template_sock.CAfile), CApath(template_sock.CApath)
00065 {
00066 }
00067
00068 s3_socket_tcp_ssl::~s3_socket_tcp_ssl()
00069 {
00070 close(false);
00071 if(ssl != 0)
00072 {
00073 SSL_free(ssl);
00074 ssl = 0;
00075 }
00076 if(ctx != 0)
00077 {
00078 SSL_CTX_free(ctx);
00079 ctx = 0;
00080 }
00081 }
00082
00083 bool s3_socket_tcp_ssl::set_keyfile_password(std::string newkey,
00084 std::string newpass)
00085 {
00086 keyfile = newkey;
00087 password = newpass;
00088 return true;
00089 }
00090
00091 bool s3_socket_tcp_ssl::set_verify_locations_file(std::string new_CAfile)
00092 {
00093 CAfile = new_CAfile;
00094 return true;
00095 }
00096
00097 bool s3_socket_tcp_ssl::set_verify_locations_path(std::string new_CApath)
00098 {
00099 CApath = new_CApath;
00100 return true;
00101 }
00102
00103 bool s3_socket_tcp_ssl::setup_ssl()
00104 {
00105 S3FC_DBG2_("s3_socket_tcp_ssl::setup_ssl()", "sock=" << sock);
00106
00107 set_errnos();
00108
00109 if( ctx == 0 && (ctx = SSL_CTX_new(SSLv3_method())) == 0 )
00110 {
00111 S3FC_DBG2_("s3_socket_tcp_ssl::setup_ssl()", "SSL_CTX_new failed");
00112 set_errnos(ERR_get_error());
00113 return false;
00114 }
00115
00116 if ( !(SSL_CTX_use_certificate_chain_file(ctx, keyfile.c_str())) )
00117 {
00118 unsigned long temperr = ERR_get_error();
00119 if ( temperr != ERROR_CERT_IN_HASH )
00120 {
00121 set_errnos(temperr);
00122 return false;
00123 }
00124 }
00125 SSL_CTX_set_default_passwd_cb_userdata(ctx, &password);
00126 SSL_CTX_set_default_passwd_cb(ctx, s3_socket_tcp_ssl::password_cb);
00127 if ( !(SSL_CTX_use_PrivateKey_file(
00128 ctx, keyfile.c_str(), SSL_FILETYPE_PEM)) )
00129 {
00130 S3FC_DBG_("Can't read key file: " << keyfile);
00131 set_errnos(ERR_get_error());
00132 return false;
00133 }
00134
00135 const char* CAfile_c = CAfile.size() > 0 ? CAfile.c_str() : 0;
00136 const char* CApath_c = CApath.size() > 0 ? CApath.c_str() : 0;
00137 if ( !(SSL_CTX_load_verify_locations(ctx, CAfile_c, CApath_c)) )
00138 {
00139 S3FC_DBG_("Can't read CA list");
00140 set_errnos(ERR_get_error());
00141 return false;
00142 }
00143 #if (OPENSSL_VERSION_NUMBER < 0x00905100L)
00144 SSL_CTX_set_verify_depth(ctx,1);
00145 #endif
00146
00147 #if 1
00148 if ( ssl != 0 )
00149 {
00150 SSL_free(ssl);
00151 ssl = 0;
00152 }
00153 #endif
00154
00155 if( (ssl = SSL_new(ctx)) == 0)
00156 {
00157 S3FC_DBG2_("s3_socket_tcp_ssl::setup_ssl()", "SSL_new failed");
00158 set_errnos(ERR_get_error());
00159 return false;
00160 }
00161 if( SSL_set_fd(ssl, sock) != 1 )
00162 {
00163 S3FC_DBG2_("s3_socket_tcp_ssl::setup_ssl()", "SSL_set_fd failed");
00164 set_errnos(ERR_get_error());
00165 return false;
00166 }
00167 return true;
00168 }
00169
00170 void s3_socket_tcp_ssl::set_errnos(unsigned long new_ssl_error)
00171 {
00172 ssl_error = new_ssl_error;
00173 s3_socket_tcp::set_errnos();
00174 }
00175
00176 bool s3_socket_tcp_ssl::accept(s3_socket_tcp_ssl& newsock,
00177 sockaddr_in* clientname)
00178 {
00179 S3FC_DBG_("s3_socket_tcp_ssl::accept()");
00180
00181 set_errnos();
00182
00183 if( !s3_socket_tcp::accept(newsock, clientname) )
00184 {
00185 return false;
00186 }
00187
00188 if ( !newsock.set_keyfile_password(keyfile, password) ||
00189 !newsock.set_verify_locations_file(CAfile) ||
00190 !newsock.set_verify_locations_path(CApath) ||
00191 !newsock.setup_ssl() )
00192 {
00193 return false;
00194 }
00195 SSL* ssl_accepted = newsock.ssl;
00196 assert( ssl_accepted != 0 );
00197 SSL_set_accept_state(ssl_accepted);
00198 int ret;
00199
00200 if( (ret = SSL_accept(ssl_accepted)) != 1 )
00201 {
00202 S3FC_DBG2_("s3_socket_tcp_ssl::accept()",
00203 "SSL_accept failed, ret = " << ret);
00204 set_errnos(SSL_get_error(ssl, ret));
00205 return false;
00206 }
00207 return true;
00208 }
00209
00210 bool s3_socket_tcp_ssl::connect(const std::string& IP_address, in_port_t port)
00211 {
00212 S3FC_DBG_( "s3_socket_tcp_ssl::connect(IP_address=" << IP_address
00213 << ",port=" << port << ")" );
00214
00215 set_errnos();
00216 if( !setup_ssl() )
00217 {
00218 return false;
00219 }
00220 SSL_set_connect_state(ssl);
00221 if( !s3_socket_tcp::connect(IP_address, port) )
00222 {
00223 return false;
00224 }
00225
00226 int ret;
00227 if( (ret = SSL_connect(ssl)) != 1 )
00228 {
00229 set_errnos(SSL_get_error(ssl, ret));
00230 return false;
00231 }
00232 return true;
00233 }
00234
00235 bool s3_socket_tcp_ssl::close(bool reinit)
00236 {
00237 S3FC_DBG_("s3_socket_tcp_ssl::close(reinit=" << reinit << ")");
00238
00239 set_errnos();
00240 int ret;
00241
00242 if( ssl != 0 && (ret = SSL_shutdown(ssl)) != 1 )
00243 {
00244 set_errnos(SSL_get_error(ssl, ret));
00245 return false;
00246 }
00247 if( !s3_socket_tcp::close(reinit) )
00248 {
00249 return false;
00250 }
00251 return true;
00252 }
00253
00254 std::string s3_socket_tcp_ssl::get_error() const
00255 {
00256 S3FC_DBG2_( "s3_socket_tcp_ssl::get_error()",
00257 "ssl_error = " << ssl_error );
00258
00259 if(ssl_error != ERR_LIB_NONE)
00260 {
00261 return ERR_error_string(ssl_error, 0);
00262 }
00263 return s3_socket_tcp::get_error();
00264 }
00265
00266 bool s3_socket_tcp_ssl::read(void *data, int size, s3_semaphore *term)
00267 {
00268 S3FC_DBG_("s3_socket_tcp_ssl::read()");
00269
00270 set_errnos();
00271 char* buf = reinterpret_cast<char*>(data);
00272 int rem_size = size;
00273 do
00274 {
00275
00276 s3_thread_base::test_cancel();
00277
00278
00279 if (term != 0 && term->try_wait())
00280 {
00281 s3_socket_tcp::set_errnos(-2);
00282 return false;
00283 }
00284
00285 switch (select_rd(1) )
00286 {
00287 case -1 : return false;
00288 case 0 : continue;
00289 case 1 : break;
00290 default: throw s3_generic_exception("s3_socket_tcp_ssl::read()",
00291 "select_rd() returned an invalid value");
00292 }
00293
00294 int ret;
00295 assert( ssl != 0 );
00296 if( (ret = SSL_read(ssl, buf, size)) <= 0 )
00297 {
00298 set_errnos(SSL_get_error(ssl, ret));
00299 return false;
00300 }
00301 rem_size -= ret;
00302 buf = buf + ret;
00303 }
00304 while (rem_size > 0);
00305
00306 return true;
00307 }
00308
00309 bool s3_socket_tcp_ssl::write(void *data, int size, int max_packet,
00310 s3_semaphore *term)
00311 {
00312 S3FC_DBG_("s3_socket_tcp_ssl::write()");
00313
00314 set_errnos();
00315 char* buf = reinterpret_cast<char*>(data);
00316 int rem_size = size;
00317 int packet_size = size;
00318 if ( max_packet > size )
00319 {
00320 packet_size = max_packet;
00321 }
00322
00323 do
00324 {
00325 s3_thread_base::test_cancel();
00326
00327
00328 if (term != 0 && term->try_wait())
00329 {
00330 s3_socket_tcp::set_errnos(-2);
00331 return false;
00332 }
00333
00334 switch (select_wr(1) )
00335 {
00336 case -1 : return false;
00337 case 0 : continue;
00338 case 1 : break;
00339 default: throw s3_generic_exception("s3_socket_tcp_ssl::write()",
00340 "select_wr() returned an invalid value");
00341 }
00342
00343 int ret;
00344 if( (ret = SSL_write(ssl, buf, size)) <= 0 )
00345 {
00346 set_errnos(SSL_get_error(ssl, ret));
00347 return false;
00348 }
00349 rem_size = rem_size - ret;
00350 buf = buf + ret;
00351 if (rem_size < packet_size) packet_size = rem_size;
00352 }
00353 while (rem_size > 0);
00354
00355 return true;
00356 }
00357
00358 int s3_socket_tcp_ssl::password_cb(char *buf, int num, int rwflag,
00359 void *userdata)
00360 {
00361 std::string* pass = reinterpret_cast<std::string*>(userdata);
00362 if ( num < static_cast<int>(pass->size()) )
00363 {
00364 return 0;
00365 }
00366 pass->copy(buf, num);
00367 return pass->size();
00368 }