34
34
35
35
import elasticapm
36
36
from elasticapm .instrumentation .packages .base import AbstractInstrumentedModule
37
+ from elasticapm .traces import execution_context
37
38
from elasticapm .utils .logging import get_logger
38
39
39
40
logger = get_logger ("elasticapm.instrument" )
40
41
41
42
should_capture_body_re = re .compile ("/(_search|_msearch|_count|_async_search|_sql|_eql)(/|$)" )
42
43
43
44
44
- class ElasticSearchConnectionMixin ( object ):
45
- query_methods = ( "search" , "count" , "delete_by_query" )
45
+ class ElasticsearchConnectionInstrumentation ( AbstractInstrumentedModule ):
46
+ name = "elasticsearch_connection"
46
47
47
- def get_signature ( self , args , kwargs ):
48
- args_len = len ( args )
49
- http_method = args [ 0 ] if args_len else kwargs . get ( "method" )
50
- http_path = args [ 1 ] if args_len > 1 else kwargs . get ( "url" )
48
+ instrument_list = [
49
+ ( "elasticsearch.connection.http_urllib3" , "Urllib3HttpConnection.perform_request" ),
50
+ ( "elasticsearch.connection.http_requests" , "RequestsHttpConnection.perform_request" ),
51
+ ]
51
52
52
- return "ES %s %s" % (http_method , http_path )
53
+ def call (self , module , method , wrapped , instance , args , kwargs ):
54
+ span = execution_context .get_span ()
55
+
56
+ self ._update_context_by_request_data (span .context , instance , args , kwargs )
57
+
58
+ status_code , headers , raw_data = wrapped (* args , ** kwargs )
59
+
60
+ span .context ["http" ] = {"status_code" : status_code }
53
61
54
- def get_context (self , instance , args , kwargs ):
62
+ return status_code , headers , raw_data
63
+
64
+ def _update_context_by_request_data (self , context , instance , args , kwargs ):
55
65
args_len = len (args )
56
66
url = args [1 ] if args_len > 1 else kwargs .get ("url" )
57
67
params = args [2 ] if args_len > 2 else kwargs .get ("params" )
58
68
body_serialized = args [3 ] if args_len > 3 else kwargs .get ("body" )
59
69
60
70
should_capture_body = bool (should_capture_body_re .search (url ))
61
71
62
- context = { "db" : {"type" : "elasticsearch" } }
72
+ context [ "db" ] = {"type" : "elasticsearch" }
63
73
if should_capture_body :
64
74
query = []
65
75
# using both q AND body is allowed in some API endpoints / ES versions,
@@ -76,32 +86,42 @@ def get_context(self, instance, args, kwargs):
76
86
query .append (body_serialized )
77
87
if query :
78
88
context ["db" ]["statement" ] = "\n \n " .join (query )
89
+
79
90
context ["destination" ] = {
80
91
"address" : instance .host ,
81
92
"service" : {"name" : "elasticsearch" , "resource" : "elasticsearch" , "type" : "db" },
82
93
}
83
- return context
84
94
85
95
86
- class ElasticsearchConnectionInstrumentation ( ElasticSearchConnectionMixin , AbstractInstrumentedModule ):
96
+ class ElasticsearchTransportInstrumentation ( AbstractInstrumentedModule ):
87
97
name = "elasticsearch_connection"
88
98
89
99
instrument_list = [
90
- ("elasticsearch.connection.http_urllib3" , "Urllib3HttpConnection.perform_request" ),
91
- ("elasticsearch.connection.http_requests" , "RequestsHttpConnection.perform_request" ),
100
+ ("elasticsearch.transport" , "Transport.perform_request" ),
92
101
]
93
102
94
103
def call (self , module , method , wrapped , instance , args , kwargs ):
95
- signature = self .get_signature (args , kwargs )
96
- context = self .get_context (instance , args , kwargs )
97
-
98
104
with elasticapm .capture_span (
99
- signature ,
105
+ self . _get_signature ( args , kwargs ) ,
100
106
span_type = "db" ,
101
107
span_subtype = "elasticsearch" ,
102
108
span_action = "query" ,
103
- extra = context ,
109
+ extra = {} ,
104
110
skip_frames = 2 ,
105
111
leaf = True ,
106
- ):
107
- return wrapped (* args , ** kwargs )
112
+ ) as span :
113
+ result_data = wrapped (* args , ** kwargs )
114
+
115
+ try :
116
+ span .context ["db" ]["rows_affected" ] = result_data ["hits" ]["total" ]["value" ]
117
+ except (KeyError , TypeError ):
118
+ pass
119
+
120
+ return result_data
121
+
122
+ def _get_signature (self , args , kwargs ):
123
+ args_len = len (args )
124
+ http_method = args [0 ] if args_len else kwargs .get ("method" )
125
+ http_path = args [1 ] if args_len > 1 else kwargs .get ("url" )
126
+
127
+ return "ES %s %s" % (http_method , http_path )
0 commit comments